├── .dockerignore ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ └── bug_report.md └── workflows │ ├── CI.yml │ └── docs.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Cargo.lock ├── Cargo.toml ├── Dockerfile ├── LICENSE ├── README.md ├── SECURITY.md ├── bench ├── attention.pdf ├── colpali.pdf └── mistral.pdf ├── docs ├── CNAME ├── assets │ ├── 128x128.png │ ├── 128x128@2x.png │ ├── 32x32.png │ ├── Square107x107Logo.png │ ├── Square142x142Logo.png │ ├── Square150x150Logo.png │ ├── Square284x284Logo.png │ ├── Square30x30Logo.png │ ├── Square310x310Logo.png │ ├── Square44x44Logo.png │ ├── Square71x71Logo.png │ ├── Square89x89Logo.png │ ├── StoreLogo.png │ ├── demo.gif.REMOVED.git-id │ ├── icon.icns │ ├── icon.ico │ └── icon.png ├── blog │ ├── .authors.yml │ ├── index.md │ └── posts │ │ ├── Gemini-Agent.md │ │ ├── Journey.md │ │ ├── ReleaseNotes5-5.md │ │ ├── Smolagent.md │ │ ├── adapter-development-guide.md │ │ ├── colpali.md │ │ ├── embed-anything.md │ │ ├── image.png │ │ ├── lumo.md │ │ ├── pycon_talk.md │ │ ├── tracing_in_lumo.md │ │ ├── v0.5.md │ │ └── vector-streaming.md ├── guides │ ├── adapters.md │ ├── colpali.md │ ├── images.md │ ├── ocr.md │ ├── onnx_models.md │ └── semantic.md ├── index.md ├── references.md └── roadmap │ ├── contribution.md │ └── roadmap.md ├── examples ├── adapters │ ├── chromadb_adaptor.py │ ├── elastic.py │ ├── lancedb_adapter.py │ ├── milvus_db.py │ ├── opensearch.py │ ├── pinecone_db.py │ ├── qdrant.py │ └── weaviate_db.py ├── audio.py ├── clip.py ├── cohere_pdf.py ├── colbert.py ├── colpali.py ├── hybridsearch.py ├── images │ ├── blue_jacket.jpg │ ├── coat.jpg │ ├── faux.webp │ ├── shirt.jpg │ └── white shirt.jpg ├── late_chunking.py ├── model2vec.py ├── notebooks │ ├── EmbedAnything_X_smolagent.ipynb │ ├── SingleStoreDB.ipynb │ └── colpali.ipynb ├── onnx_models.py ├── reranker.py ├── semantic_chunking.py ├── splade.py ├── text.py ├── text_ocr.py └── web.py ├── mkdocs.yml ├── processors ├── Cargo.toml ├── README.md └── src │ ├── docx_processor.rs │ ├── html_processor.rs │ ├── lib.rs │ ├── markdown_processor.rs │ ├── pdf │ ├── mod.rs │ ├── pdf_processor.rs │ └── tesseract │ │ ├── command.rs │ │ ├── error.rs │ │ ├── input.rs │ │ ├── mod.rs │ │ ├── output_boxes.rs │ │ ├── output_config_parameters.rs │ │ ├── output_data.rs │ │ └── parse_line_util.rs │ ├── processor.rs │ └── txt_processor.rs ├── pyproject.toml ├── python ├── Cargo.lock ├── Cargo.toml ├── python │ └── embed_anything │ │ ├── __init__.py │ │ ├── _embed_anything.pyi │ │ ├── libiomp5.so │ │ ├── libiomp5md.dll │ │ ├── py.typed │ │ └── vectordb.py └── src │ ├── config.rs │ ├── lib.rs │ └── models │ ├── colbert.rs │ ├── colpali.rs │ ├── mod.rs │ └── reranker.rs ├── rust ├── Cargo.toml ├── examples │ ├── audio.rs │ ├── bert.rs │ ├── clip.rs │ ├── cloud.rs │ ├── cohere_pdf.rs │ ├── colbert.rs │ ├── colpali.rs │ ├── late_chunking.rs │ ├── model2vec.rs │ ├── ort_models.rs │ ├── reranker.rs │ ├── splade.rs │ └── web_embed.rs └── src │ ├── chunkers │ ├── cumulative.rs │ ├── mod.rs │ └── statistical.rs │ ├── config.rs │ ├── embeddings │ ├── cloud │ │ ├── cohere.rs │ │ ├── mod.rs │ │ └── openai.rs │ ├── embed.rs │ ├── local │ │ ├── bert.rs │ │ ├── clip.rs │ │ ├── colbert.rs │ │ ├── colpali.rs │ │ ├── colpali_ort.rs │ │ ├── jina.rs │ │ ├── mod.rs │ │ ├── model2vec.rs │ │ ├── model_info.rs │ │ ├── modernbert.rs │ │ ├── ort_bert.rs │ │ ├── ort_jina.rs │ │ ├── pooling.rs │ │ └── text_embedding.rs │ ├── mod.rs │ └── utils.rs │ ├── file_loader.rs │ ├── file_processor │ ├── audio │ │ ├── 
audio_processor.rs │ │ ├── melfilters.bytes │ │ ├── melfilters128.bytes │ │ ├── mod.rs │ │ └── pcm_decode.rs │ └── mod.rs │ ├── lib.rs │ ├── models │ ├── bert.rs │ ├── clip │ │ ├── mod.rs │ │ ├── text_model.rs │ │ └── vision_model.rs │ ├── colpali.rs │ ├── gemma.rs │ ├── jina_bert.rs │ ├── mod.rs │ ├── modernbert.rs │ ├── paligemma.rs │ ├── siglip.rs │ └── with_tracing.rs │ ├── reranker │ ├── mod.rs │ └── model.rs │ └── text_loader.rs ├── test_files ├── audio │ ├── samples_hp0.wav │ └── samples_jfk.wav ├── bank.txt ├── clip │ ├── cat1.jpg │ ├── cat2.jpeg │ ├── dog1.jpg │ ├── dog2.jpeg │ └── monkey1.jpg ├── colpali.pdf ├── linear.pdf ├── test.docx ├── test.html ├── test.md ├── test.pdf └── test.txt └── tests └── model_tests ├── conftest.py ├── test_adapter.py ├── test_audio.py ├── test_bert.py ├── test_clip.py ├── test_colpali.py ├── test_jina.py └── test_openai.py /.dockerignore: -------------------------------------------------------------------------------- 1 | target -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | *.{cmd,[cC][mM][dD]} text eol=crlf 3 | *.{bat,[bB][aA][tT]} text eol=crlf -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 
39 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - '*' 9 | workflow_dispatch: 10 | 11 | permissions: 12 | contents: write 13 | 14 | jobs: 15 | deploy: 16 | runs-on: ${{ matrix.platform.runner }} 17 | strategy: 18 | matrix: 19 | platform: 20 | - runner: ubuntu-latest 21 | target: x86_64 22 | 23 | steps: 24 | - name: Checkout code 25 | uses: actions/checkout@v4 26 | - name: Configure Git Credentials 27 | run: | 28 | git config user.name github-actions[bot] 29 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 30 | 31 | - name: Set up Python 32 | uses: actions/setup-python@v5 33 | with: 34 | python-version: '3.x' # Choose the Python version you need 35 | - name: Deploy Docs 36 | run: | 37 | pip install mkdocs mkdocstrings[python] mkdocs-material griffe==0.49.0 38 | mkdocs build 39 | mkdocs gh-deploy --force 40 | 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | .pytest_cache/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | embed*.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | .venv/ 14 | env/ 15 | bin/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | eggs/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | include/ 25 | man/ 26 | venv/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # Installer logs 32 | pip-log.txt 33 | pip-delete-this-directory.txt 34 | pip-selfcheck.json 35 | test.ipynb 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | 45 | # Translations 46 | *.mo 47 | 48 | # Mr Developer 49 | .mr.developer.cfg 50 | .project 51 | .pydevproject 52 | 53 | # Rope 54 | .ropeproject 55 | 56 | # Django stuff: 57 | *.log 58 | *.pot 59 | 60 | .DS_Store 61 | 62 | # Sphinx documentation 63 | docs/_build/ 64 | 65 | # PyCharm 66 | .idea/ 67 | 68 | # VSCode 69 | .vscode/ 70 | 71 | # Pyenv 72 | .python-version 73 | 74 | 75 | site 76 | 77 | *.ipynb 78 | 79 | tmp/ 80 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## 🚀 Getting Started 2 | To get started, check the [Issues Section] for tasks labeled "Good First Issue" or "Help Needed". These issues are perfect for new contributors or those looking to make a valuable impact quickly. 3 | 4 | If you find an issue you want to tackle: 5 | 6 | Comment on the issue to let us know you’d like to work on it. 7 | Wait for confirmation—an admin will assign the issue to you. 8 | 💻 Setting Up Your Development Environment 9 | To start working on the project, follow these steps: 10 | 11 | 1. Fork the Repository: Begin by forking the repository from the dev branch. We do not allow direct contributions to the main branch. 12 | 2. Clone Your Fork: After forking, clone the repository to your local machine. 13 | 3. Create a New Branch: For each contribution, create a new branch following the naming convention: feature/your-feature-name or bugfix/your-bug-name. 
14 | 15 | ## 🛠️ Contributing Guidelines 16 | 🔍 Reporting Bugs 17 | If you find a bug, here’s how to report it effectively: 18 | 19 | Title: Use a clear and descriptive title, with appropriate labels. 20 | Description: Provide a detailed description of the issue, including: 21 | Steps to reproduce the problem. 22 | Expected and actual behavior. 23 | 24 | Any relevant logs, screenshots, or additional context. 25 | Submit the Bug Report: Open a new issue in the [Issues Section] and include all the details. This helps us understand and resolve the problem faster. 26 | 27 | ## 🐍 Contributing to Python Code 28 | If you're contributing to the Python codebase, follow these steps: 29 | 30 | 1. Create an Independent File: Write your code in a new file within the python folder. 31 | 2. Build with Maturin: After writing your code, use maturin build to build the package. 32 | 3. Import and Call the Function: 33 | 4. Use the following import syntax: 34 | from embed_anything. import * 35 | 5. Then, call the function using: 36 | from embed_anything import 37 | Feel free to open an issue if you encounter any problems during the process. 38 | 39 | 🧩 Contributing to Adapters 40 | To contribute to adapters, follow these guidelines: 41 | 42 | 1. Implement Adapter Class: Create an Adapter class that supports the create, add, and delete operations for your specific use case. 43 | 2. Check Existing Adapters: Use the existing Pinecone and Weaviate adapters as references to maintain consistency in structure and functionality. 44 | 3. Testing: Ensure your adapter is tested thoroughly before submitting a pull request. 45 | 46 | 47 | ### 🔄 Submitting a Pull Request 48 | Once your contribution is ready: 49 | 50 | Push Your Branch: Push your branch to your forked repository. 51 | 52 | Submit a Pull Request (PR): Open a PR from your branch to the dev branch of the main repository. Ensure your PR includes: 53 | 54 | 1. A clear description of the changes. 55 | 2. Any relevant issue numbers (e.g., "Closes #123"). 56 | 3. Wait for Review: A maintainer will review your PR. Please be responsive to any feedback or requested changes. -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "processors", 4 | "rust", 5 | "python", 6 | ] 7 | # Python package needs to be built by maturin. 
8 | exclude = ["python"] 9 | resolver = "2" 10 | 11 | [workspace.package] 12 | edition = "2021" 13 | license = "Apache-2.0" 14 | description = "Embed anything at lightning speed" 15 | repository = "https://github.com/StarlightSearch/EmbedAnything" 16 | authors = ["Akshay Ballal "] 17 | exclude = ["test_files/*", "python", "*.py", "pyproject.toml", "examples/images/*", "mkdocs.yml", "docs/*", "tests/*", ".github", "Dockerfile", "docs"] 18 | version = "0.6.0" 19 | 20 | [workspace.dependencies] 21 | pdf-extract = "0.9.0" 22 | candle-nn = { version = "0.8.3" } 23 | candle-transformers = { version = "0.8.3" } 24 | candle-core = { version = "0.8.3" } 25 | candle-flash-attn = { version = "0.8.3" } 26 | 27 | strum = "0.27.0" 28 | strum_macros = "0.27.0" 29 | 30 | 31 | 32 | [profile.dev] 33 | rpath = true 34 | 35 | [profile.release] 36 | rpath = true 37 | 38 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM lukemathwalker/cargo-chef:latest-rust-1 AS chef 2 | WORKDIR /app 3 | 4 | FROM chef AS planner 5 | COPY . . 6 | RUN cargo chef prepare --recipe-path recipe.json 7 | 8 | FROM chef AS builder 9 | COPY --from=planner /app/recipe.json recipe.json 10 | # Build dependencies - this is the caching Docker layer! 11 | RUN cargo chef cook --release --recipe-path recipe.json 12 | # Build application 13 | 14 | RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null \ 15 | && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list \ 16 | && apt-get update \ 17 | && apt-get install -y intel-oneapi-mkl-devel \ 18 | && export LD_LIBRARY_PATH="/opt/intel/oneapi/compiler/2024.2/lib:$LD_LIBRARY_PATH" 19 | 20 | RUN apt-get install libssl-dev pkg-config python3-full python3-pip -y 21 | RUN pip3 install maturin[patchelf] --break-system-packages 22 | 23 | COPY . . 24 | RUN maturin build --release --features mkl,extension-module 25 | 26 | FROM python:3.11-slim 27 | 28 | WORKDIR /app 29 | 30 | COPY . . 31 | 32 | COPY --from=builder /app/target/wheels . 33 | 34 | RUN pip install *.whl 35 | 36 | RUN pip install numpy pillow pytest onnxruntime pymupdf 37 | 38 | # Set the library path, initializing it if not already set 39 | ENV LD_LIBRARY_PATH="/usr/lib" 40 | 41 | ENV ORT_DYLIB_PATH="/usr/local/lib/python3.11/site-packages/onnxruntime/capi/libonnxruntime.so.1.19.2" 42 | 43 | CMD ["pytest", "tests"] -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | Use this section to tell people about which versions of your project are 6 | currently being supported with security updates.
7 | 8 | | Version | Supported | 9 | | ------- | ------------------ | 10 | | 0.2.x | :white_check_mark: | 11 | 12 | ## Reporting a Vulnerability 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /bench/attention.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/bench/attention.pdf -------------------------------------------------------------------------------- /bench/colpali.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/bench/colpali.pdf -------------------------------------------------------------------------------- /bench/mistral.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/bench/mistral.pdf -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | embed-anything.com 2 | -------------------------------------------------------------------------------- /docs/assets/128x128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/128x128.png -------------------------------------------------------------------------------- /docs/assets/128x128@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/128x128@2x.png -------------------------------------------------------------------------------- /docs/assets/32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/32x32.png -------------------------------------------------------------------------------- /docs/assets/Square107x107Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/Square107x107Logo.png -------------------------------------------------------------------------------- /docs/assets/Square142x142Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/Square142x142Logo.png -------------------------------------------------------------------------------- /docs/assets/Square150x150Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/Square150x150Logo.png -------------------------------------------------------------------------------- /docs/assets/Square284x284Logo.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/Square284x284Logo.png -------------------------------------------------------------------------------- /docs/assets/Square30x30Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/Square30x30Logo.png -------------------------------------------------------------------------------- /docs/assets/Square310x310Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/Square310x310Logo.png -------------------------------------------------------------------------------- /docs/assets/Square44x44Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/Square44x44Logo.png -------------------------------------------------------------------------------- /docs/assets/Square71x71Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/Square71x71Logo.png -------------------------------------------------------------------------------- /docs/assets/Square89x89Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/Square89x89Logo.png -------------------------------------------------------------------------------- /docs/assets/StoreLogo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/StoreLogo.png -------------------------------------------------------------------------------- /docs/assets/demo.gif.REMOVED.git-id: -------------------------------------------------------------------------------- 1 | 6621126c28656a9e39d08c30e4312d92e6a8cef7 -------------------------------------------------------------------------------- /docs/assets/icon.icns: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/icon.icns -------------------------------------------------------------------------------- /docs/assets/icon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/icon.ico -------------------------------------------------------------------------------- /docs/assets/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/assets/icon.png -------------------------------------------------------------------------------- /docs/blog/.authors.yml: -------------------------------------------------------------------------------- 1 | authors: 2 | akshay: 3 | name: Akshay Ballal 4 | 
description: Creator of EmbedAnything 5 | avatar: https://pbs.twimg.com/profile_images/1660187462357127168/6dV9SpLi_400x400.jpg 6 | sonam: 7 | name: Sonam Pankaj 8 | description: Creator of EmbedAnything 9 | avatar: https://pbs.twimg.com/profile_images/1798985783292125184/L6YQmg1Q_400x400.jpg -------------------------------------------------------------------------------- /docs/blog/index.md: -------------------------------------------------------------------------------- 1 | # 📰 All Posts 2 | -------------------------------------------------------------------------------- /docs/blog/posts/ReleaseNotes5-5.md: -------------------------------------------------------------------------------- 1 | --- 2 | draft: false 3 | date: 2025-05-25 4 | authors: 5 | - sonam 6 | slug: release-notes-6 7 | title: Release Notes 6.0 8 | --- 9 | 10 | Super Excited to share the latest development in our library, which essentially giving you more embedding choices -- Cohere and siglip, new chunking method-- late chunking and more crates that facilitates amazing modality and maintainability for our rust codebase, --processor crate. so let's dive in. 11 | 12 | 13 | 14 | ## Late Chunking 15 | 16 | The new 0.5.6 version adds Late Chunking to EmbedAnything, a technique introduced by Jina AI and Weaviate. 17 | Here's how we've implemented Late Chunking in EA: 18 | 19 | 𝗕𝗮𝘁𝗰𝗵 𝗮𝘀 𝗖𝗵𝘂𝗻𝗸 𝗚𝗿𝗼𝘂𝗽: In EmbedAnything, with late chunking enabled, the batch size determines the number of neighboring chunks that will be processed together. 20 | 21 | 𝗝𝗼𝗶𝗻𝘁 𝗘𝗺𝗯𝗲𝗱𝗱𝗶𝗻𝗴: The grouped chunks are fed into the embedding model as a single, larger input. This allows the model to capture relationships and dependencies between adjacent chunks. 22 | 23 | 𝗘𝗺𝗯𝗲𝗱𝗱𝗶𝗻𝗴 𝗦𝗽𝗹𝗶𝘁: After embedding, the combined output is divided back into the embeddings for the original, individual chunks. 24 | 25 | 𝗠𝗲𝗮𝗻 𝗣𝗼𝗼𝗹𝗶𝗻𝗴 (𝗽𝗲𝗿 𝗖𝗵𝘂𝗻𝗸): Mean pooling is then applied to each individual chunk's embedding, incorporating the contextual information learned during the joint embedding phase. 26 | 27 | 𝐾𝑒𝑦 𝐵𝑒𝑛𝑒𝑓𝑖𝑡𝑠: 28 | 29 | 𝗖𝗼𝗻𝘁𝗲𝘅𝘁-𝗔𝘄𝗮𝗿𝗲 𝗘𝗺𝗯𝗲𝗱𝗱𝗶𝗻𝗴𝘀: By embedding neighboring chunks together, we capture crucial contextual information that would be lost with independent chunking. 30 | 31 | 𝗢𝗽𝘁𝗶𝗺𝗶𝘇𝗲𝗱 𝗥𝗲𝘁𝗿𝗶𝗲𝘃𝗮𝗹 𝗣𝗲𝗿𝗳𝗼𝗿𝗺𝗮𝗻𝗰𝗲: Expect a significant improvement in the accuracy and relevance of your search results. 32 | 33 | ```python 34 | model:EmbeddingModel = EmbeddingModel.from_pretrained_onnx( 35 | WhichModel.Jina, hf_model_id="jinaai/jina-embeddings-v2-small-en", path_in_repo="model.onnx" 36 | ) 37 | config = TextEmbedConfig( 38 | chunk_size=1000, 39 | batch_size=8, 40 | splitting_strategy="sentence", 41 | late_chunking=True, 42 | ) 43 | 44 | # Embed a single file 45 | data: list[EmbedData] = model.embed_file("test_files/attention.pdf", config=config) 46 | ``` 47 | 48 | 49 | ## Cohere Embed 4: 50 | 51 | 🧊 Single embedding per document, even for multimodal inputs 52 | 📚 Handles up to 128K tokens – perfect for long-form business documents 53 | 🗃️ Supports compressed vector formats (int8, binary) for real-world scalability 54 | 🌐 Multilingual across 100+ languages 55 | 56 | The catch? It’s not open-source—and even if it were, the model would be quite hefty to run locally. But if you’re already using cloud-based embeddings like OpenAI’s, Embed v4 is worth testing. 
57 | 58 | ```python 59 | # Initialize the model once 60 | model: EmbeddingModel = EmbeddingModel.from_pretrained_cloud( 61 | WhichModel.CohereVision, model_id="embed-v4.0" 62 | ) 63 | 64 | ``` 65 | 66 | ## SigLIP 67 | 68 | We already had CLIP support, but many of you asked for SigLIP support. It outperforms CLIP on zero-shot classification at smaller batch sizes, and it also has better memory efficiency. 69 | 70 | ```python 71 | # Load the model. 72 | model = embed_anything.EmbeddingModel.from_pretrained_hf( 73 | embed_anything.WhichModel.Clip, 74 | model_id="google/siglip-base-patch16-224", 75 | ) 76 | ``` 77 | 78 | ## Processor Crate: 79 | 80 | This crate contains various "processors" that accept files and produce a chunked, metadata-rich document description. This is especially helpful for retrieval-augmented generation! 81 | 82 | We have also received some additional cool feature requests on GitHub, which we would like to implement. If you want to help out, please check out EmbedAnything on GitHub. We would love to have your contribution. 🚀 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /docs/blog/posts/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/docs/blog/posts/image.png -------------------------------------------------------------------------------- /docs/blog/posts/tracing_in_lumo.md: -------------------------------------------------------------------------------- 1 | --- 2 | draft: false 3 | date: 2025-05-01 4 | authors: 5 | - sonam 6 | slug: observability 7 | title: Easy Observability to our agentic framework; LUMO 8 | --- 9 | 10 | In the rapidly evolving landscape of AI agents, particularly those employing Large Language Models (LLMs), observability and tracing have emerged as fundamental requirements rather than optional features. As agents become more complex and handle increasingly critical tasks, understanding their inner workings, debugging issues, and establishing accountability becomes paramount. 11 | 12 | 13 | ## Understanding Observability in AI Agents 14 | 15 | Observability refers to the ability to understand the internal state of a system through its external outputs. In AI agents, comprehensive observability encompasses: 16 | 17 | 1. **Decision Visibility**: Transparency into how and why an agent made specific decisions 18 | 2. **State Tracking**: Monitoring the agent's internal state as it evolves throughout task execution 19 | 3. **Resource Utilization**: Measuring computational resources, API calls, and external interactions 20 | 4. **Performance Metrics**: Capturing response times, completion rates, and quality indicators 21 | 22 | ## The Multi-Faceted Value of Tracing and Observability 23 | 24 | ### 1. Debugging and Troubleshooting 25 | 26 | AI agents, especially those leveraging LLMs, operate with inherent complexity and sometimes unpredictability. Without proper observability: 27 | 28 | - **Silent Failures** become common, where agents fail without clear indications of what went wrong 29 | - **Root Cause Analysis** becomes nearly impossible as there's no trace of the execution path 30 | 31 | ### 2. Performance Optimization 32 | 33 | Observability provides crucial insights for optimizing agent performance: 34 | 35 | - **Caching Opportunities**: Recognize repeated patterns that could benefit from caching 36 | 37 | ### 3.
Security and Compliance 38 | 39 | As agents gain more capabilities and autonomy, security becomes increasingly critical: 40 | 41 | - **Audit Trails**: Maintain comprehensive logs of all agent actions for compliance and security reviews 42 | - **Prompt Injection Detection**: Identify potential attempts to manipulate the agent's behavior 43 | 44 | ### 4. User Trust and Transparency 45 | 46 | For end-users working with AI agents, transparency builds trust: 47 | 48 | - **Action Justification**: Provide clear explanations for why the agent took specific actions 49 | - **Confidence Indicators**: Show reliability metrics for different types of responses 50 | 51 | ### 5. Continuous Improvement 52 | 53 | Observability creates a foundation for systematic improvement: 54 | 55 | - **Pattern Recognition**: Identify standard failure modes or suboptimal behaviors 56 | - **A/B Testing**: Compare different agent configurations with detailed performance metrics 57 | 58 | ## Implementing Effective Observability in Lumo 59 | 60 | For tracing and observability, open your shell profile: 61 | 62 | ``` 63 | vim ~/.bashrc 64 | ``` 65 | Add the three keys from Langfuse: 66 | 67 | ``` 68 | LANGFUSE_PUBLIC_KEY_DEV=your-dev-public-key 69 | LANGFUSE_SECRET_KEY_DEV=your-dev-secret-key 70 | LANGFUSE_HOST_DEV=http://localhost:3000 # Or your dev Langfuse instance URL 71 | ``` 72 | 73 | Start lumo-cli or the lumo server, then press: 74 | 75 | ``` 76 | CTRL + C 77 | ``` 78 | And it’s added to the dashboard. 79 | 80 | 81 | 82 | ## Conclusion 83 | 84 | Observability and tracing are no longer optional components for serious AI agent implementations. They form the foundation for reliable, secure, and continuously improving systems. As agents take on more responsibility and autonomy, the ability to observe, understand, and explain their behavior becomes not just a technical requirement but an ethical imperative. 85 | 86 | Organizations building or deploying AI agents should invest early in robust observability infrastructure, treating it as a core capability rather than an afterthought. The insights gained will improve current systems and also inform the development of better, more trustworthy agents in the future. -------------------------------------------------------------------------------- /docs/blog/posts/v0.5.md: -------------------------------------------------------------------------------- 1 | --- 2 | draft: false 3 | date: 2025-01-01 4 | authors: 5 | - sonam 6 | slug: modernBERT 7 | title: version 0.5 8 | --- 9 | 10 | We are thrilled to share that EmbedAnything version 0.5 is out now, and it comprises some insane developments, like support for ModernBERT and reranker models, along with ingestion pipeline support for DOCX and HTML. Let's get into the details. 11 | 12 | Best of all has been support for late-interaction models, both ColPali and ColBERT, on ONNX. 13 | 14 | 15 | 16 | 17 | 1. **ModernBert** Support: Well, it made quite a splash, and we were obliged to add it to the fastest inference engine, EmbedAnything. In addition to being faster and more accurate, ModernBERT also increases context length to 8k tokens (compared to just 512 for most encoders), and is the first encoder-only model that includes a large amount of code in its training data. 18 | 2. **ColPali - ONNX**: Running the ColPali model directly on a local machine might not always be feasible. To address this, we developed a **quantized version of ColPali**.
Find it on our hugging face, link [here](https://huggingface.co/starlight-ai/colpali-v1.2-merged-onnx). You could also run it both on Candle and on ONNX. 19 | 3. **ColBERT**: ColBERT is a *fast* and *accurate* retrieval model, enabling scalable BERT-based search over large text collections in tens of milliseconds. 20 | 4. **ReRankers:** EmbedAnything recently contributed for the support of reranking models to Candle so as to add it in our own library. It can support any kind of reranking models. Precision meets performance! Use reranking models to refine your retrieval results for even greater accuracy. 21 | 5. **Jina V3:** Also contributed to V3 models, for Jina can seamlessly integrate any V3 model. 22 | 6. **𝗗𝗢𝗖𝗫 𝗣𝗿𝗼𝗰𝗲𝘀𝘀𝗶𝗻𝗴** 23 | 24 | Effortlessly extract text from .docx files and convert it into embeddings. Simplify your document workflows like never before! 25 | 26 | 7. **𝗛𝗧𝗠𝗟 𝗣𝗿𝗼𝗰𝗲𝘀𝘀𝗶𝗻𝗴:** 27 | 28 | Parsing and embedding HTML documents just got easier! 29 | 30 | ✅ Extract rich metadata with embeddings 31 | ✅ Handle code blocks separately for better context 32 | 33 | Supercharge your documentation retrieval with these advanced capabilities. -------------------------------------------------------------------------------- /docs/guides/adapters.md: -------------------------------------------------------------------------------- 1 | # Using Vector Database Adapters 2 | 3 | ## Using Elasticsearch 4 | 5 | To use Elasticsearch, you need to install the `elasticsearch` package. 6 | ```bash 7 | pip install elasticsearch 8 | ``` 9 | 10 | ``` python 11 | --8<-- "examples/adapters/elastic.py" 12 | ``` 13 | 14 | ## Using Weaviate 15 | 16 | To use Weaviate, you need to install the `weaviate-client` package. 17 | 18 | ```bash 19 | pip install weaviate-client 20 | ``` 21 | 22 | ``` python 23 | --8<-- "examples/adapters/weaviate_db.py" 24 | ``` 25 | 26 | ## Using Pinecone 27 | 28 | To use Pinecone, you need to install the `pinecone` package. 29 | 30 | ```bash 31 | pip install pinecone 32 | ``` 33 | 34 | ``` python 35 | --8<-- "examples/adapters/pinecone_db.py" 36 | ``` 37 | 38 | ## Using Qdrant 39 | 40 | To use [Qdrant](https://qdrant.tech/), you need to install the `qdrant-client` package. 41 | 42 | ```bash 43 | pip install qdrant-client 44 | ``` 45 | 46 | ``` python 47 | --8<-- "examples/adapters/qdrant.py" 48 | ``` 49 | 50 | ## Using Milvus 51 | 52 | To use [Milvus](https://milvus.io/), you need to install the `pymilvus` package. 53 | 54 | ```bash 55 | pip install pymilvus 56 | ``` 57 | 58 | ``` python 59 | --8<-- "examples/adapters/milvus_db.py" 60 | ``` 61 | -------------------------------------------------------------------------------- /docs/guides/colpali.md: -------------------------------------------------------------------------------- 1 | # Using Colpali 2 | This example leverages the ColpaliModel from the EmbedAnything library, specifically designed for high-performance document embedding and semantic search. Colpali supports both native and ONNX formats, making it versatile for fast, efficient model loading. 3 | 4 | ``` python 5 | --8<-- "examples/colpali.py" 6 | ``` -------------------------------------------------------------------------------- /docs/guides/images.md: -------------------------------------------------------------------------------- 1 | # Searching Images 2 | 3 | This example shows how to use the EmbeddingModel from EmbedAnything to perform semantic image search within a directory, leveraging the CLIP model for accurate, language-guided matching. 
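Before the bundled example, here is a minimal sketch of the idea for orientation. It only uses calls that appear elsewhere in this repository (`from_pretrained_hf`, `embed_image_directory`, `embed_query`); the directory path, the query text, and the NumPy similarity ranking are illustrative assumptions rather than part of the guide.

```python
import numpy as np
import embed_anything
from embed_anything import EmbeddingModel, WhichModel

# Load a CLIP model (any checkpoint from the supported list below should work).
model = EmbeddingModel.from_pretrained_hf(
    WhichModel.Clip, model_id="openai/clip-vit-base-patch32"
)

# Embed every image in a directory (path is illustrative).
images = embed_anything.embed_image_directory("test_files/clip", embedder=model)

# Embed a natural-language query with the same model.
query = embed_anything.embed_query(["a photo of a dog"], embedder=model)[0]

# Rank images by cosine similarity between the query and image embeddings.
image_vecs = np.array([e.embedding for e in images])
query_vec = np.array(query.embedding)
scores = image_vecs @ query_vec / (
    np.linalg.norm(image_vecs, axis=1) * np.linalg.norm(query_vec)
)
best = images[int(np.argmax(scores))]
print(best.metadata.get("file_name"), float(scores.max()))
```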
4 | 5 | ``` python 6 | --8<-- "examples/clip.py" 7 | ``` 8 | 9 | ## Supported Models 10 | 11 | EmbedAnything supports the following models for image search: 12 | 13 | - openai/clip-vit-base-patch32 14 | - openai/clip-vit-base-patch16 15 | - openai/clip-vit-large-patch14-336 16 | - openai/clip-vit-large-patch14 -------------------------------------------------------------------------------- /docs/guides/ocr.md: -------------------------------------------------------------------------------- 1 | # Use PDFs that need OCR 2 | 3 | Embed Anything can be used to embed scanned documents using OCR. This is useful for tasks such as document search and retrieval. You can set `use_ocr=True` in the `TextEmbedConfig` to enable OCR. But this requires `tesseract` and `poppler` to be installed. 4 | 5 | You can install `tesseract` and `poppler` using the following commands: 6 | 7 | ## Install Tesseract and Poppler 8 | 9 | ### Windows 10 | 11 | For Tesseract, download the installer from [here](https://github.com/UB-Mannheim/tesseract/wiki) and install it. 12 | 13 | For Poppler, download the installer from [here](https://github.com/oschwartz10612/poppler-windows?tab=readme-ov-file) and install it. 14 | 15 | ### MacOS 16 | 17 | For Tesseract, you can install it using Homebrew. 18 | 19 | ``` bash 20 | brew install tesseract 21 | ``` 22 | 23 | For Poppler, you can install it using Homebrew. 24 | 25 | ``` bash 26 | brew install poppler 27 | ``` 28 | 29 | ### Linux 30 | 31 | For Tesseract, you can install it using the package manager for your Linux distribution. For example, on Ubuntu, you can install it using: 32 | 33 | ``` bash 34 | sudo apt install tesseract-ocr 35 | sudo apt install libtesseract-dev 36 | 37 | ``` 38 | 39 | For Poppler, you can install it using the package manager for your Linux distribution. For example, on Ubuntu, you can install it using: 40 | 41 | ``` bash 42 | sudo apt install poppler-utils 43 | ``` 44 | 45 | For more information, refer to the [Tesseract installation guide](https://tesseract-ocr.github.io/tessdoc/Installation.html). 
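With Tesseract and Poppler installed, enabling OCR is just a matter of setting the flag described above on the config. Here is a minimal sketch before the bundled example; the model choice, chunking values, and file path are illustrative, and only the `use_ocr` flag comes from this guide.

```python
import embed_anything
from embed_anything import EmbeddingModel, TextEmbedConfig, WhichModel

model = EmbeddingModel.from_pretrained_hf(
    WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2"
)

# use_ocr routes scanned pages through Tesseract before chunking and embedding.
config = TextEmbedConfig(chunk_size=1000, batch_size=32, use_ocr=True)

data = embed_anything.embed_file(
    "path/to/scanned_document.pdf", embedder=model, config=config
)
print(len(data), "chunks embedded")
```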
46 | 47 | ## Example Usage 48 | 49 | ``` python 50 | --8<-- "examples/text_ocr.py" 51 | ``` 52 | -------------------------------------------------------------------------------- /docs/guides/onnx_models.md: -------------------------------------------------------------------------------- 1 | # Using ONNX Models 2 | 3 | ## Supported Models 4 | 5 | | Enum Variant | Description | 6 | |----------------------------------|--------------------------------------------------| 7 | | `AllMiniLML6V2` | sentence-transformers/all-MiniLM-L6-v2 | 8 | | `AllMiniLML6V2Q` | Quantized sentence-transformers/all-MiniLM-L6-v2 | 9 | | `AllMiniLML12V2` | sentence-transformers/all-MiniLM-L12-v2 | 10 | | `AllMiniLML12V2Q` | Quantized sentence-transformers/all-MiniLM-L12-v2| 11 | | `ModernBERTBase` | nomic-ai/modernbert-embed-base | 12 | | `ModernBERTLarge` | nomic-ai/modernbert-embed-large | 13 | | `BGEBaseENV15` | BAAI/bge-base-en-v1.5 | 14 | | `BGEBaseENV15Q` | Quantized BAAI/bge-base-en-v1.5 | 15 | | `BGELargeENV15` | BAAI/bge-large-en-v1.5 | 16 | | `BGELargeENV15Q` | Quantized BAAI/bge-large-en-v1.5 | 17 | | `BGESmallENV15` | BAAI/bge-small-en-v1.5 - Default | 18 | | `BGESmallENV15Q` | Quantized BAAI/bge-small-en-v1.5 | 19 | | `NomicEmbedTextV1` | nomic-ai/nomic-embed-text-v1 | 20 | | `NomicEmbedTextV15` | nomic-ai/nomic-embed-text-v1.5 | 21 | | `NomicEmbedTextV15Q` | Quantized nomic-ai/nomic-embed-text-v1.5 | 22 | | `ParaphraseMLMiniLML12V2` | sentence-transformers/paraphrase-MiniLM-L6-v2 | 23 | | `ParaphraseMLMiniLML12V2Q` | Quantized sentence-transformers/paraphrase-MiniLM-L6-v2 | 24 | | `ParaphraseMLMpnetBaseV2` | sentence-transformers/paraphrase-mpnet-base-v2 | 25 | | `BGESmallZHV15` | BAAI/bge-small-zh-v1.5 | 26 | | `MultilingualE5Small` | intfloat/multilingual-e5-small | 27 | | `MultilingualE5Base` | intfloat/multilingual-e5-base | 28 | | `MultilingualE5Large` | intfloat/multilingual-e5-large | 29 | | `MxbaiEmbedLargeV1` | mixedbread-ai/mxbai-embed-large-v1 | 30 | | `MxbaiEmbedLargeV1Q` | Quantized mixedbread-ai/mxbai-embed-large-v1 | 31 | | `GTEBaseENV15` | Alibaba-NLP/gte-base-en-v1.5 | 32 | | `GTEBaseENV15Q` | Quantized Alibaba-NLP/gte-base-en-v1.5 | 33 | | `GTELargeENV15` | Alibaba-NLP/gte-large-en-v1.5 | 34 | | `GTELargeENV15Q` | Quantized Alibaba-NLP/gte-large-en-v1.5 | 35 | | `JINAV2SMALLEN` | jinaai/jina-embeddings-v2-small-en | 36 | | `JINAV2BASEEN` | jinaai/jina-embeddings-v2-base-en | 37 | | `JINAV3` | jinaai/jina-embeddings-v3 | 38 | 39 | ## Example Usage 40 | 41 | ``` python 42 | --8<-- "examples/onnx_models.py" 43 | ``` 44 | -------------------------------------------------------------------------------- /docs/guides/semantic.md: -------------------------------------------------------------------------------- 1 | # Using Semantic Chunking 2 | 3 | Semantic encoding is essential for applications where maintaining the logical flow and meaning of text is critical, such as in document retrieval, question answering, or summarization. This approach ensures that embeddings capture the full intent and nuance of the original content, enhancing downstream model performance. 
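As a rough sketch of how this is typically wired up (the bundled example follows below): the `splitting_strategy="semantic"` value and the `semantic_encoder` parameter are assumptions taken from the library's semantic-chunking example and may not match the current API exactly.

```python
import embed_anything
from embed_anything import EmbeddingModel, TextEmbedConfig, WhichModel

# Model that produces the final embeddings.
model = EmbeddingModel.from_pretrained_hf(
    WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2"
)

# Smaller model used only to decide where semantically coherent chunks
# begin and end (assumed parameter name: semantic_encoder).
semantic_encoder = EmbeddingModel.from_pretrained_hf(
    WhichModel.Jina, model_id="jinaai/jina-embeddings-v2-small-en"
)

config = TextEmbedConfig(
    chunk_size=1000,
    batch_size=32,
    splitting_strategy="semantic",  # assumed value; "sentence" is shown elsewhere in these docs
    semantic_encoder=semantic_encoder,
)

data = embed_anything.embed_file(
    "test_files/attention.pdf", embedder=model, config=config
)
```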
4 | 5 | ``` python 6 | --8<-- "examples/semantic_chunking.py" 7 | ``` -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # 📚 References 2 | 3 | ::: python.python.embed_anything 4 | handler: python 5 | -------------------------------------------------------------------------------- /docs/roadmap/contribution.md: -------------------------------------------------------------------------------- 1 | # Contribution Guidelines 2 | 3 | ## 🚀 Getting Started 4 | To get started, check the [Issues Section] for tasks labeled "Good First Issue" or "Help Needed". These issues are perfect for new contributors or those looking to make a valuable impact quickly. 5 | 6 | If you find an issue you want to tackle: 7 | 8 | Comment on the issue to let us know you’d like to work on it. 9 | Wait for confirmation—an admin will assign the issue to you. 10 | 💻 Setting Up Your Development Environment 11 | To start working on the project, follow these steps: 12 | 13 | 1. Fork the Repository: Begin by forking the repository from the dev branch. We do not allow direct contributions to the main branch. 14 | 2. Clone Your Fork: After forking, clone the repository to your local machine. 15 | 3. Create a New Branch: For each contribution, create a new branch following the naming convention: feature/your-feature-name or bugfix/your-bug-name. 16 | 17 | ## 🛠️ Contributing Guidelines 18 | 🔍 Reporting Bugs 19 | If you find a bug, here’s how to report it effectively: 20 | 21 | Title: Use a clear and descriptive title, with appropriate labels. 22 | Description: Provide a detailed description of the issue, including: 23 | Steps to reproduce the problem. 24 | Expected and actual behavior. 25 | 26 | Any relevant logs, screenshots, or additional context. 27 | Submit the Bug Report: Open a new issue in the [Issues Section] and include all the details. This helps us understand and resolve the problem faster. 28 | 29 | ## 🐍 Contributing to Python Code 30 | If you're contributing to the Python codebase, follow these steps: 31 | 32 | 1. Create an Independent File: Write your code in a new file within the python folder. 33 | 2. Build with Maturin: After writing your code, use maturin build to build the package. 34 | 3. Import and Call the Function: 35 | 4. Use the following import syntax: 36 | from embed_anything. import * 37 | 5. Then, call the function using: 38 | from embed_anything import 39 | Feel free to open an issue if you encounter any problems during the process. 40 | 41 | 🧩 Contributing to Adapters 42 | To contribute to adapters, follow these guidelines: 43 | 44 | 1. Implement Adapter Class: Create an Adapter class that supports the create, add, and delete operations for your specific use case. 45 | 2. Check Existing Adapters: Use the existing Pinecone and Weaviate adapters as references to maintain consistency in structure and functionality. 46 | 3. Testing: Ensure your adapter is tested thoroughly before submitting a pull request. 47 | 48 | 49 | ### 🔄 Submitting a Pull Request 50 | Once your contribution is ready: 51 | 52 | Push Your Branch: Push your branch to your forked repository. 53 | 54 | Submit a Pull Request (PR): Open a PR from your branch to the dev branch of the main repository. Ensure your PR includes: 55 | 56 | 1. A clear description of the changes. 57 | 2. Any relevant issue numbers (e.g., "Closes #123"). 58 | 3. Wait for Review: A maintainer will review your PR. 
Please be responsive to any feedback or requested changes. -------------------------------------------------------------------------------- /docs/roadmap/roadmap.md: -------------------------------------------------------------------------------- 1 | 2 | # 🏎️ RoadMap 3 | 4 | ## Accomplishments 5 | 6 | One of the aims of EmbedAnything is to allow AI engineers to easily use state-of-the-art embedding models on typical files and documents. A lot has already been accomplished here; these are the formats we support right now, and a few more are on the way.
7 | 8 | ### 🖼️ Modalities and Source 9 | 10 | We’re excited to share that we've expanded our platform to support multiple modalities, including: 11 | 12 | - [x] Audio files 13 | 14 | - [x] Markdowns 15 | 16 | - [x] Websites 17 | 18 | - [x] Images 19 | 20 | - [ ] Videos 21 | 22 | - [ ] Graph 23 | 24 | This gives you the flexibility to work with various data types all in one place! 🌐
25 | 26 | ### 💜 Product 27 | We’ve rolled out some major updates in version 0.3 to improve both functionality and performance. Here’s what’s new: 28 | 29 | - Semantic Chunking: Optimized chunking strategy for better Retrieval-Augmented Generation (RAG) workflows. 30 | 31 | - Streaming for Efficient Indexing: We’ve introduced streaming for memory-efficient indexing in vector databases. Want to know more? Check out our article on this feature here: https://www.analyticsvidhya.com/blog/2024/09/vector-streaming/ 32 | 33 | - Zero-Shot Applications: Explore our zero-shot application demos to see the power of these updates in action. 34 | 35 | - Intuitive Functions: Version 0.3 includes a complete refactor for more intuitive functions, making the platform easier to use. 36 | 37 | - Chunkwise Streaming: Instead of file-by-file streaming, we now support chunkwise streaming, allowing for more flexible and efficient data processing. 38 | 39 | Check out the latest release : and see how these features can supercharge your GenerativeAI pipeline! ✨ 40 | 41 | ## 🚀Coming Soon
42 | 43 | ### ⚙️ Performance 44 | We've received quite a few questions about why we're using Candle, so here's a quick explanation: 45 | 46 | One of the main reasons is that Candle doesn't require any specific ONNX format models, which means it can work seamlessly with any Hugging Face model. This flexibility has been a key factor for us. However, we also recognize that we’ve been compromising a bit on speed in favor of that flexibility. 47 | 48 | What’s Next? 49 | To address this, we’re excited to announce that we’re introducing Candle-ONNX alongside our existing Hugging Face framework: 50 | 51 | ➡️ Support for GGUF models
52 | - Significantly faster performance
53 | - Stay tuned for these exciting updates! 🚀
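For anyone who wants the faster path today, ONNX checkpoints can already be loaded through the existing API. This simply repeats the pattern shown in the release notes earlier in this document (the model id and file name are the ones shown there):

```python
from embed_anything import EmbeddingModel, WhichModel

# Load an ONNX export instead of the default Candle weights.
model = EmbeddingModel.from_pretrained_onnx(
    WhichModel.Jina,
    hf_model_id="jinaai/jina-embeddings-v2-small-en",
    path_in_repo="model.onnx",
)
```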
54 | 55 | 56 | ### 🫐Embeddings: 57 | 58 | We had multimodality in our infrastructure from day one. We have already included it for websites, images, and audio, but we want to expand it further to: 59 | 60 | ☑️ Graph embedding -- build DeepWalk embeddings (depth-first walks plus word2vec)
61 | ☑️Video Embedding
62 | ☑️ Yolo Clip
63 | 64 | 65 | ### 🌊Expansion to other Vector Adapters 66 | 67 | We currently support a wide range of vector databases for streaming embeddings, including the following (a minimal adapter skeleton is sketched after the list): 68 | 69 | - Elastic: thanks to the amazing and active Elastic team for the contribution
70 | - Weaviate
71 | - Pinecone
72 | - Qdrant
73 | - Milvus
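Each of these integrations is implemented as a small `Adapter` subclass; several complete ones appear under `examples/adapters/` later in this document. A new database can be wired in with the same shape. Below is a minimal skeleton, with the method set taken from the bundled adapters and everything database-specific left as a placeholder:

```python
from typing import Dict, List

from embed_anything import EmbedData
from embed_anything.vectordb import Adapter


class MyDatabaseAdapter(Adapter):  # hypothetical adapter for your own vector store
    def __init__(self, client):
        self.client = client  # placeholder: your database client

    def create_index(self, index_name: str):
        ...  # create the collection/index in your database

    def convert(self, embeddings: List[EmbedData]) -> List[Dict]:
        # Map EmbedData (text, embedding, metadata) into your database's row format.
        return [
            {"text": e.text, "vector": e.embedding, "metadata": e.metadata}
            for e in embeddings
        ]

    def upsert(self, data: List[EmbedData]):
        rows = self.convert(data)
        ...  # write the rows to your database

    def delete_index(self, index_name: str):
        ...  # drop the collection/index
```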
74 | 75 | But we're not stopping there! We're actively working to expand this list. 76 | 77 | Want to Contribute? 78 | If you’d like to add support for your favorite vector database, we’d love to have your help! Check out our contribution.md for guidelines, or feel free to reach out directly starlight-search@proton.me. Let's build something amazing together! 💡 79 | -------------------------------------------------------------------------------- /examples/adapters/chromadb_adaptor.py: -------------------------------------------------------------------------------- 1 | import embed_anything 2 | import os 3 | from typing import Dict, List, Optional 4 | from embed_anything import EmbedData 5 | from embed_anything.vectordb import Adapter 6 | from uuid import uuid4 7 | import chromadb 8 | 9 | class ChromaAdapter(Adapter): 10 | def __init__(self, db_path: str, embedding_dimension: int): 11 | self.db_path = db_path 12 | self.dimension = embedding_dimension 13 | 14 | os.makedirs(db_path, exist_ok=True) 15 | # Initialize ChromaDB with persistence path 16 | self.client = chromadb.PersistentClient(path=self.db_path) 17 | 18 | def create_index(self, table_name: str): 19 | # Get or create collection with provided name 20 | self.collection = self.client.get_or_create_collection( 21 | name=table_name 22 | ) 23 | 24 | def convert(self, embeddings: List[EmbedData]) -> Dict[str, List]: 25 | # Format data for ChromaDB's expected structure 26 | ids = [] 27 | documents = [] 28 | embeddings_list = [] 29 | metadatas = [] 30 | 31 | for embedding in embeddings: 32 | id = str(uuid4()) 33 | ids.append(id) 34 | documents.append(embedding.text) 35 | embeddings_list.append(embedding.embedding) 36 | 37 | metadata = { 38 | "file_name": embedding.metadata["file_name"], 39 | "modified": embedding.metadata["modified"], 40 | "created": embedding.metadata["created"] 41 | } 42 | metadatas.append(metadata) 43 | 44 | return { 45 | "ids": ids, 46 | "documents": documents, 47 | "embeddings": embeddings_list, 48 | "metadatas": metadatas 49 | } 50 | 51 | def delete_index(self, table_name: str): 52 | # Remove collection if it exists 53 | self.client.delete_collection(name=table_name) 54 | 55 | 56 | def upsert(self, data: List[EmbedData]): 57 | # Add documents and embeddings to collection 58 | converted_data = self.convert(data) 59 | 60 | self.collection.add( 61 | ids=converted_data["ids"], 62 | documents=converted_data["documents"], 63 | embeddings=converted_data["embeddings"], 64 | metadatas=converted_data["metadatas"] 65 | ) 66 | 67 | 68 | def main(): 69 | # Initialize adapter and model 70 | chroma_adapter = ChromaAdapter(db_path="tmp/chromadb", embedding_dimension=384) 71 | 72 | model = embed_anything.EmbeddingModel.from_pretrained_hf( 73 | embed_anything.WhichModel.Bert, 74 | model_id="sentence-transformers/all-MiniLM-L12-v2" 75 | ) 76 | 77 | # Create collection and embed documents 78 | chroma_adapter.create_index("docs") 79 | 80 | data = embed_anything.embed_file( 81 | "MoE.pdf", 82 | embedder=model, 83 | adapter=chroma_adapter, 84 | ) 85 | 86 | # Example search 87 | query_embedding = embed_anything.embed_query(['what is mistral'], embedder=model)[0].embedding 88 | 89 | results = chroma_adapter.collection.query( 90 | query_embeddings=[query_embedding], 91 | n_results=5 92 | ) 93 | 94 | for doc in results["documents"][0]: 95 | print(doc) 96 | 97 | if __name__ == "__main__": 98 | main() -------------------------------------------------------------------------------- /examples/adapters/elastic.py: 
-------------------------------------------------------------------------------- 1 | import embed_anything 2 | import os 3 | 4 | from typing import Dict, List 5 | from embed_anything import EmbedData 6 | from embed_anything.vectordb import Adapter 7 | from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel 8 | 9 | from elasticsearch import Elasticsearch 10 | from elasticsearch.helpers import bulk 11 | 12 | 13 | class ElasticsearchAdapter(Adapter): 14 | 15 | def __init__(self, api_key: str, cloud_id: str, index_name: str = "anything"): 16 | self.es = Elasticsearch(cloud_id=cloud_id, api_key=api_key) 17 | self.index_name = index_name 18 | 19 | def create_index( 20 | self, dimension: int, metric: str, mappings={}, settings={}, **kwargs 21 | ): 22 | 23 | if "index_name" in kwargs: 24 | self.index_name = kwargs["index_name"] 25 | 26 | self.es.indices.create( 27 | index=self.index_name, mappings=mappings, settings=settings 28 | ) 29 | 30 | def convert(self, embeddings: List[List[EmbedData]]) -> List[Dict]: 31 | data = [] 32 | for embedding in embeddings: 33 | data.append( 34 | { 35 | "text": embedding.text, 36 | "embeddings": embedding.embedding, 37 | "metadata": { 38 | "file_name": embedding.metadata["file_name"], 39 | "modified": embedding.metadata["modified"], 40 | "created": embedding.metadata["created"], 41 | }, 42 | } 43 | ) 44 | return data 45 | 46 | def delete_index(self, index_name: str): 47 | self.es.indices.delete(index=index_name) 48 | 49 | def gendata(self, data): 50 | for doc in data: 51 | yield doc 52 | 53 | def upsert(self, data: List[Dict]): 54 | data = self.convert(data) 55 | bulk(client=self.es, index="anything", actions=self.gendata(data)) 56 | 57 | 58 | index_name = "anything" 59 | elastic_api_key = os.environ.get("ELASTIC_API_KEY") 60 | elastic_cloud_id = os.environ.get("ELASTIC_CLOUD_ID") 61 | 62 | # Initialize the ElasticsearchAdapter Class 63 | elasticsearch_adapter = ElasticsearchAdapter( 64 | api_key=elastic_api_key, 65 | cloud_id=elastic_cloud_id, 66 | index_name=index_name, 67 | ) 68 | 69 | # Prase PDF and insert documents into Elasticsearch. 70 | model = EmbeddingModel.from_pretrained_hf( 71 | WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2" 72 | ) 73 | 74 | 75 | data = embed_anything.embed_file( 76 | "/home/sonamAI/projects/EmbedAnything/test_files/attention.pdf", 77 | embedder=model, 78 | adapter=elasticsearch_adapter 79 | ) 80 | 81 | # Create an Index with explicit mappings. 
82 | mappings = { 83 | "properties": { 84 | "embeddings": {"type": "dense_vector", "dims": 384}, 85 | "text": {"type": "text"}, 86 | } 87 | } 88 | settings = {} 89 | 90 | elasticsearch_adapter.create_index( 91 | dimension=384, 92 | metric="cosine", 93 | mappings=mappings, 94 | settings=settings, 95 | ) 96 | 97 | # Delete an Index 98 | elasticsearch_adapter.delete_index(index_name=index_name) 99 | -------------------------------------------------------------------------------- /examples/adapters/lancedb_adapter.py: -------------------------------------------------------------------------------- 1 | import embed_anything 2 | import os 3 | from typing import Dict, List 4 | from embed_anything import EmbedData 5 | from embed_anything.vectordb import Adapter 6 | from uuid import uuid4 7 | import lancedb 8 | 9 | class LanceAdapter(Adapter): 10 | def __init__(self, db_path: str, embedding_dimension: int): 11 | 12 | import pyarrow as pa # For schema definition 13 | 14 | self.db_path = db_path 15 | self.connection = lancedb.connect(self.db_path) 16 | self.dimension = embedding_dimension 17 | 18 | # Define schema using pyarrow 19 | self.schema = pa.schema([ 20 | pa.field("embeddings", pa.list_(pa.float32(), self.dimension)), 21 | pa.field("text", pa.string()), 22 | pa.field("file_name", pa.string()), 23 | pa.field("modified", pa.string()), 24 | pa.field("created", pa.string()) 25 | ]) 26 | 27 | def create_index(self, table_name: str): 28 | self.table_name = table_name 29 | self.connection = lancedb.connect(self.db_path) 30 | self.table = self.connection.create_table(table_name, schema=self.schema) 31 | 32 | 33 | def convert(self, embeddings: List[List[EmbedData]]) -> List[Dict]: 34 | data = [] 35 | for embedding in embeddings: 36 | 37 | data.append( 38 | { 39 | "text": embedding.text, 40 | "embeddings": embedding.embedding, 41 | "file_name": embedding.metadata["file_name"], 42 | "modified": embedding.metadata["modified"], 43 | "created": embedding.metadata["created"], 44 | } 45 | ) 46 | return data 47 | 48 | def delete_index(self, table_name: str): 49 | self.connection.drop_table(table_name) 50 | 51 | def upsert(self, data: EmbedData): 52 | self.table.add(self.convert(data)) 53 | 54 | 55 | def main(): 56 | # Initialize adapter 57 | lance_adapter = LanceAdapter(db_path="tmp/lancedb", embedding_dimension=384) 58 | 59 | # Initialize model 60 | model = embed_anything.EmbeddingModel.from_pretrained_hf( 61 | embed_anything.WhichModel.Bert, 62 | model_id="sentence-transformers/all-MiniLM-L12-v2" 63 | ) 64 | 65 | # Create index and embed data 66 | if "docs" in lance_adapter.connection.table_names(): 67 | lance_adapter.delete_index("docs") 68 | lance_adapter.create_index("docs") 69 | 70 | data = embed_anything.embed_file( 71 | "test_files/attention.pdf", 72 | embedder=model, 73 | adapter=lance_adapter, 74 | ) 75 | 76 | # Example search 77 | query_vec = embed_anything.embed_query(['attention'], embedder=model)[0].embedding 78 | docs = lance_adapter.table.search(query_vec).limit(5).to_pandas()["text"] 79 | print(docs[2]) 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /examples/adapters/pinecone_db.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict, List 3 | import uuid 4 | import embed_anything 5 | import os 6 | 7 | from embed_anything.vectordb import Adapter 8 | from pinecone import Pinecone, ServerlessSpec 9 | 10 | from embed_anything import 
EmbedData, EmbeddingModel, WhichModel, TextEmbedConfig 11 | 12 | 13 | class PineconeAdapter(Adapter): 14 | """ 15 | Adapter class for interacting with Pinecone, a vector database service. 16 | """ 17 | 18 | def __init__(self, api_key: str): 19 | """ 20 | Initializes a new instance of the PineconeAdapter class. 21 | 22 | Args: 23 | api_key (str): The API key for accessing the Pinecone service. 24 | """ 25 | super().__init__(api_key) 26 | self.pc = Pinecone(api_key=self.api_key) 27 | self.index_name = None 28 | 29 | def create_index( 30 | self, 31 | dimension: int, 32 | metric: str = "cosine", 33 | index_name: str = "anything", 34 | spec=ServerlessSpec(cloud="aws", region="us-east-1"), 35 | ): 36 | """ 37 | Creates a new index in Pinecone. 38 | 39 | Args: 40 | dimension (int): The dimensionality of the embeddings. 41 | metric (str, optional): The distance metric to use for similarity search. Defaults to "cosine". 42 | index_name (str, optional): The name of the index. Defaults to "anything". 43 | spec (ServerlessSpec, optional): The serverless specification for the index. Defaults to AWS in us-east-1 region. 44 | """ 45 | self.index_name = index_name 46 | self.pc.create_index( 47 | name=index_name, dimension=dimension, metric=metric, spec=spec 48 | ) 49 | 50 | def delete_index(self, index_name: str): 51 | """ 52 | Deletes an existing index from Pinecone. 53 | 54 | Args: 55 | index_name (str): The name of the index to delete. 56 | """ 57 | self.pc.delete_index(name=index_name) 58 | 59 | def convert(self, embeddings: List[EmbedData]) -> List[Dict]: 60 | """ 61 | Converts a list of embeddings into the required format for upserting into Pinecone. 62 | 63 | Args: 64 | embeddings (List[EmbedData]): The list of embeddings to convert. 65 | 66 | Returns: 67 | List[Dict]: The converted data in the required format for upserting into Pinecone. 68 | """ 69 | data_emb = [] 70 | 71 | for embedding in embeddings: 72 | data_emb.append( 73 | { 74 | "id": str(uuid.uuid4()), 75 | "values": embedding.embedding, 76 | "metadata": { 77 | "text": embedding.text, 78 | "file": re.split( 79 | r"/|\\", embedding.metadata.get("file_name", "") 80 | )[-1], 81 | }, 82 | } 83 | ) 84 | return data_emb 85 | 86 | def upsert(self, data: List[Dict]): 87 | """ 88 | Upserts data into the specified index in Pinecone. 89 | 90 | Args: 91 | data (List[Dict]): The data to upsert into Pinecone. 92 | 93 | Raises: 94 | ValueError: If the index has not been created before upserting data. 
95 | """ 96 | data = self.convert(data) 97 | if not self.index_name: 98 | raise ValueError("Index must be created before upserting data") 99 | self.pc.Index(name=self.index_name).upsert(data) 100 | 101 | 102 | # Initialize the PineconeEmbedder class 103 | api_key = os.environ.get("PINECONE_API_KEY") 104 | index_name = "anything" 105 | pinecone_adapter = PineconeAdapter(api_key) 106 | 107 | try: 108 | pinecone_adapter.delete_index("anything") 109 | except: 110 | pass 111 | 112 | # Initialize the PineconeEmbedder class 113 | 114 | pinecone_adapter.create_index(dimension=512, metric="cosine") 115 | 116 | 117 | model = EmbeddingModel.from_pretrained_hf( 118 | WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2" 119 | ) 120 | 121 | 122 | data = embed_anything.embed_file( 123 | "/home/sonamAI/projects/EmbedAnything/test_files/attention.pdf", 124 | embedder=model, 125 | adapter=pinecone_adapter, 126 | ) 127 | 128 | 129 | 130 | data = embed_anything.embed_image_directory( 131 | "test_files", 132 | embedder=model, 133 | adapter=pinecone_adapter 134 | ) 135 | print(data) 136 | -------------------------------------------------------------------------------- /examples/adapters/qdrant.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import List, Dict 3 | from qdrant_client import QdrantClient 4 | from qdrant_client.models import ( 5 | Distance, 6 | VectorParams, 7 | PointStruct, 8 | ) 9 | import embed_anything 10 | from embed_anything import EmbedData, EmbeddingModel, WhichModel 11 | from embed_anything.vectordb import Adapter 12 | 13 | 14 | class QdrantAdapter(Adapter): 15 | """ 16 | Adapter class for interacting with [Qdrant](https://qdrant.tech/). 17 | """ 18 | 19 | def __init__(self, client: QdrantClient): 20 | """ 21 | Initializes a new instance of the QdrantAdapter class. 
22 | 23 | Args: 24 | client : An instance of qdrant_client.QdrantClient 25 | """ 26 | self.client = client 27 | 28 | def create_index( 29 | self, 30 | dimension: int, 31 | metric: Distance = Distance.COSINE, 32 | index_name: str = "embed-anything", 33 | **kwargs, 34 | ): 35 | self.collection_name = index_name 36 | 37 | if not self.client.collection_exists(index_name): 38 | self.client.create_collection( 39 | collection_name=index_name, 40 | vectors_config=VectorParams(size=dimension, distance=metric), 41 | ) 42 | 43 | def delete_index(self, index_name: str): 44 | self.client.delete_collection(collection_name=index_name) 45 | 46 | def convert(self, embeddings: List[EmbedData]) -> List[PointStruct]: 47 | points = [] 48 | for embedding in embeddings: 49 | points.append( 50 | PointStruct( 51 | id=str(uuid.uuid4()), 52 | vector=embedding.embedding, 53 | payload={ 54 | "text": embedding.text, 55 | "file_name": embedding.metadata["file_name"], 56 | "modified": embedding.metadata["modified"], 57 | "created": embedding.metadata["created"], 58 | }, 59 | ) 60 | ) 61 | return points 62 | 63 | def upsert(self, data: List[Dict]): 64 | points = self.convert(data) 65 | self.client.upsert( 66 | collection_name=self.collection_name, 67 | points=points, 68 | ) 69 | 70 | 71 | def main(): 72 | adapter = QdrantAdapter(QdrantClient(location=":memory:")) 73 | adapter.create_index(dimension=384) 74 | 75 | model = EmbeddingModel.from_pretrained_hf( 76 | WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2" 77 | ) 78 | 79 | embed_anything.embed_file( 80 | "test_files/attention.pdf", 81 | embedder=model, 82 | adapter=adapter, 83 | ) 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /examples/adapters/weaviate_db.py: -------------------------------------------------------------------------------- 1 | import weaviate, os 2 | import weaviate.classes as wvc 3 | from tqdm.auto import tqdm 4 | import embed_anything 5 | from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel 6 | from embed_anything.vectordb import Adapter 7 | import textwrap 8 | 9 | ## Weaviate Adapter 10 | 11 | from typing import List 12 | 13 | 14 | class WeaviateAdapter(Adapter): 15 | def __init__(self, api_key, url): 16 | super().__init__(api_key) 17 | self.client = weaviate.connect_to_weaviate_cloud( 18 | cluster_url=url, auth_credentials=wvc.init.Auth.api_key(api_key) 19 | ) 20 | if self.client.is_ready(): 21 | print("Weaviate is ready") 22 | 23 | def create_index(self, index_name: str): 24 | self.index_name = index_name 25 | self.collection = self.client.collections.create( 26 | index_name, vectorizer_config=wvc.config.Configure.Vectorizer.none() 27 | ) 28 | return self.collection 29 | 30 | def convert(self, embeddings: List[EmbedData]): 31 | data = [] 32 | for embedding in embeddings: 33 | property = embedding.metadata 34 | property["text"] = embedding.text 35 | data.append( 36 | wvc.data.DataObject(properties=property, vector=embedding.embedding) 37 | ) 38 | return data 39 | 40 | def upsert(self, data_): 41 | data_ = self.convert(data_) 42 | self.client.collections.get(self.index_name).data.insert_many(data_) 43 | 44 | def delete_index(self, index_name: str): 45 | self.client.collections.delete(index_name) 46 | 47 | 48 | URL = "URL" 49 | API_KEY = "API_KEY" 50 | weaviate_adapter = WeaviateAdapter(API_KEY, URL) 51 | 52 | 53 | # create index 54 | index_name = "Test_index" 55 | if index_name in 
weaviate_adapter.client.collections.list_all(): 56 | weaviate_adapter.delete_index(index_name) 57 | weaviate_adapter.create_index("Test_index") 58 | 59 | 60 | model = EmbeddingModel.from_pretrained_hf( 61 | WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2" 62 | ) 63 | 64 | 65 | data = embed_anything.embed_file( 66 | "/home/sonamAI/projects/EmbedAnything/test_files/attention.pdf", 67 | embedder=model, 68 | adapter=weaviate_adapter, 69 | ) 70 | 71 | query_vector = embed_anything.embed_query(["What is attention"], embedder=model)[ 72 | 0 73 | ].embedding 74 | 75 | 76 | response = weaviate_adapter.collection.query.near_vector( 77 | near_vector=query_vector, 78 | limit=2, 79 | return_metadata=wvc.query.MetadataQuery(certainty=True), 80 | ) 81 | 82 | for i in range(len(response.objects)): 83 | print(response.objects[i].properties["text"]) 84 | 85 | 86 | for res in response.objects: 87 | print(textwrap.fill(res.properties["text"], width=120), end="\n\n") 88 | -------------------------------------------------------------------------------- /examples/audio.py: -------------------------------------------------------------------------------- 1 | import embed_anything 2 | from embed_anything import ( 3 | AudioDecoderModel, 4 | EmbeddingModel, 5 | embed_audio_file, 6 | TextEmbedConfig, 7 | ) 8 | import time 9 | 10 | start_time = time.time() 11 | 12 | # choose any whisper or distilwhisper model from https://huggingface.co/distil-whisper or https://huggingface.co/collections/openai/whisper-release-6501bba2cf999715fd953013 13 | audio_decoder = AudioDecoderModel.from_pretrained_hf( 14 | "openai/whisper-tiny.en", revision="main", model_type="tiny-en", quantized=False 15 | ) 16 | 17 | embedder = EmbeddingModel.from_pretrained_hf( 18 | embed_anything.WhichModel.Bert, 19 | model_id="sentence-transformers/all-MiniLM-L6-v2", 20 | revision="main", 21 | ) 22 | 23 | config = TextEmbedConfig(chunk_size=200, batch_size=32) 24 | data = embed_anything.embed_audio_file( 25 | "test_files/audio/samples_hp0.wav", 26 | audio_decoder=audio_decoder, 27 | embedder=embedder, 28 | text_embed_config=config, 29 | ) 30 | print(data[0].metadata) 31 | end_time = time.time() 32 | print("Time taken: ", end_time - start_time) 33 | -------------------------------------------------------------------------------- /examples/clip.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import embed_anything 3 | from embed_anything import EmbedData 4 | import time 5 | 6 | start = time.time() 7 | 8 | # Load the model. 
9 | model = embed_anything.EmbeddingModel.from_pretrained_hf( 10 | embed_anything.WhichModel.Clip, 11 | model_id="google/siglip-base-patch16-224", 12 | ) 13 | data: list[EmbedData] = embed_anything.embed_image_directory( 14 | "test_files", embedder=model 15 | ) 16 | 17 | # Convert the embeddings to a numpy array 18 | embeddings = np.array([data.embedding for data in data]) 19 | 20 | # Embed a query 21 | query = ["Photo of a monkey"] 22 | query_embedding = np.array( 23 | embed_anything.embed_query(query, embedder=model)[0].embedding 24 | ) 25 | 26 | # Calculate the similarities between the query embedding and all the embeddings 27 | similarities = np.dot(embeddings, query_embedding) 28 | 29 | # Find the index of the most similar embedding 30 | max_index = np.argmax(similarities) 31 | 32 | print("Descending order of similarity: ") 33 | indices = np.argsort(similarities)[::-1] 34 | for idx in indices: 35 | print(data[idx].text) 36 | 37 | print("----------- ") 38 | 39 | # Print the most similar image 40 | print("Most similar image: ", data[max_index].text) 41 | end = time.time() 42 | print("Time taken: ", end - start) 43 | -------------------------------------------------------------------------------- /examples/cohere_pdf.py: -------------------------------------------------------------------------------- 1 | from embed_anything import EmbeddingModel, TextEmbedConfig, WhichModel 2 | import numpy as np 3 | from pathlib import Path 4 | from tabulate import tabulate 5 | from embed_anything import EmbedData 6 | from pdf2image import convert_from_path 7 | 8 | 9 | # Initialize the model once 10 | model: EmbeddingModel = EmbeddingModel.from_pretrained_cloud( 11 | WhichModel.CohereVision, model_id="embed-v4.0" 12 | ) 13 | 14 | 15 | # Get all PDF files in the directory 16 | directory = Path("test_files") 17 | files = directory.glob("*.pdf") 18 | # files = [Path("test_files/attention.pdf")] 19 | 20 | file_embed_data: list[EmbedData] = [] 21 | for file in files: 22 | try: 23 | embedding: list[EmbedData] = model.embed_file( 24 | str(file), TextEmbedConfig(batch_size=8) 25 | ) 26 | file_embed_data.extend(embedding) 27 | except Exception as e: 28 | print(f"Error embedding file {file}: {e}") 29 | 30 | # Define the query 31 | query = "What are the Bleu score results for the attention paper?" 
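# Note on the scoring below (an added hint, not part of the original example): the dot products
# only equal cosine similarities if the embeddings are L2-normalized. If you are unsure whether
# the embed-v4.0 vectors come back normalized, normalize them first, for example:
#   file_embeddings = file_embeddings / np.linalg.norm(file_embeddings, axis=1, keepdims=True)
#   query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)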
32 | 33 | # Scoring 34 | file_embeddings = np.array([e.embedding for e in file_embed_data]) 35 | query_embedding = model.embed_query([query]) 36 | query_embeddings = np.array([e.embedding for e in query_embedding]) 37 | print(file_embeddings.shape) 38 | print(query_embeddings.shape) 39 | 40 | 41 | scores = np.dot(query_embeddings, file_embeddings.T).squeeze() 42 | 43 | # Get top pages 44 | top_pages = np.argsort(scores)[-5:][::-1].tolist() # Convert to list 45 | 46 | print(top_pages) 47 | # Extract file names and page numbers 48 | table = [ 49 | [ 50 | file_embed_data[int(page)].metadata["file_path"], 51 | file_embed_data[int(page)].metadata["page_number"], 52 | ] 53 | for page in top_pages 54 | ] 55 | 56 | # Print the results in a table 57 | print(tabulate(table, headers=["File Name", "Page Number"], tablefmt="grid")) 58 | 59 | images = [file_embed_data[int(page)].metadata["image"] for page in top_pages] 60 | -------------------------------------------------------------------------------- /examples/colbert.py: -------------------------------------------------------------------------------- 1 | 2 | from embed_anything import ( 3 | embed_query, 4 | ColbertModel 5 | ) 6 | import os 7 | from time import time 8 | import numpy as np 9 | 10 | model:ColbertModel = ColbertModel.from_pretrained_onnx( 11 | hf_model_id="jinaai/jina-colbert-v2", 12 | path_in_repo="onnx/model.onnx", 13 | ) 14 | 15 | # model:ColbertModel = ColbertModel.from_pretrained_onnx( 16 | # hf_model_id="answerdotai/answerai-colbert-small-v1", 17 | # path_in_repo="onnx/model_fp16.onnx", 18 | # ) 19 | 20 | 21 | sentences = [ 22 | "The quick brown fox jumps over the lazy dog", 23 | "The cat is sleeping on the mat", 24 | "The dog is barking at the moon", 25 | "I love pizza", 26 | "I like to have pasta", 27 | "The dog is sitting in the park", 28 | ] 29 | 30 | query = "I like italian food" 31 | 32 | doc_embeddings = np.array([e.embedding for e in model.embed(sentences, is_doc=True)]) 33 | 34 | query_embeddings = np.array([e.embedding for e in model.embed([query], is_doc=False)]) 35 | 36 | print("shape of doc_embedddings", doc_embeddings.shape) 37 | 38 | scores = ( 39 | np.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings) 40 | .max(axis=3) 41 | .sum(axis=2) 42 | .squeeze() 43 | ) 44 | 45 | for i, score in enumerate(scores): 46 | print(f"{sentences[i]}: {score}") 47 | 48 | -------------------------------------------------------------------------------- /examples/colpali.py: -------------------------------------------------------------------------------- 1 | from embed_anything import EmbedData, ColpaliModel 2 | import numpy as np 3 | from tabulate import tabulate 4 | from pathlib import Path 5 | 6 | 7 | # Load the model 8 | # model: ColpaliModel = ColpaliModel.from_pretrained("vidore/colpali-v1.2-merged", None) 9 | 10 | # Load ONNX Model 11 | model: ColpaliModel = ColpaliModel.from_pretrained_onnx( 12 | "akshayballal/colpali-v1.2-merged-onnx", None 13 | ) 14 | 15 | # Get all PDF files in the directory 16 | directory = Path("test_files") 17 | files = list(directory.glob("*.pdf")) 18 | # files = [Path("test_files/attention.pdf")] 19 | 20 | file_embed_data: list[EmbedData] = [] 21 | for file in files: 22 | try: 23 | embedding: list[EmbedData] = model.embed_file(str(file), batch_size=1) 24 | file_embed_data.extend(embedding) 25 | except Exception as e: 26 | print(f"Error embedding file {file}: {e}") 27 | 28 | # Define the query 29 | query = "What are Positional Encodings" 30 | 31 | # Scoring 32 | file_embeddings = 
np.array([e.embedding for e in file_embed_data]) 33 | query_embedding = model.embed_query(query) 34 | query_embeddings = np.array([e.embedding for e in query_embedding]) 35 | print(file_embeddings.shape) 36 | print(query_embeddings.shape) 37 | 38 | scores = ( 39 | np.einsum("bnd,csd->bcns", query_embeddings, file_embeddings) 40 | .max(axis=3) 41 | .sum(axis=2) 42 | .squeeze() 43 | ) 44 | 45 | # Get top pages 46 | top_pages = np.argsort(scores)[-5:][::-1] 47 | 48 | # Extract file names and page numbers 49 | table = [ 50 | [ 51 | file_embed_data[page].metadata["file_path"], 52 | file_embed_data[page].metadata["page_number"], 53 | ] 54 | for page in top_pages 55 | ] 56 | 57 | # Print the results in a table 58 | print(tabulate(table, headers=["File Name", "Page Number"], tablefmt="grid")) 59 | 60 | images = [file_embed_data[page].metadata["image"] for page in top_pages] 61 | -------------------------------------------------------------------------------- /examples/hybridsearch.py: -------------------------------------------------------------------------------- 1 | from qdrant_client import QdrantClient, models 2 | from tqdm.auto import tqdm 3 | import embed_anything 4 | from embed_anything import EmbedData 5 | from embed_anything.vectordb import Adapter 6 | import uuid 7 | from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel 8 | 9 | from typing import List 10 | from qdrant_client.models import PointStruct 11 | 12 | 13 | client = QdrantClient( 14 | "cloud.qdrant.io", 15 | api_key="api", 16 | ) 17 | 18 | sentences = [ 19 | "The cat sits outside", 20 | "A man is playing guitar", 21 | "I love pasta", 22 | "The new movie is awesome", 23 | "The cat plays in the garden", 24 | "A woman watches TV", 25 | "The new movie is so great", 26 | "Do you like pizza?", 27 | ] 28 | 29 | 30 | client.create_collection( 31 | collection_name="my-hybrid-collection", 32 | vectors_config={ 33 | "jina": models.VectorParams( 34 | size=768, 35 | distance=models.Distance.COSINE, 36 | ) 37 | }, 38 | sparse_vectors_config={ 39 | "bm42": models.SparseVectorParams( 40 | modifier=models.Modifier.IDF, 41 | ) 42 | }, 43 | ) 44 | 45 | 46 | query_text = ["best programming language for beginners?"] 47 | 48 | 49 | jina_model = EmbeddingModel.from_pretrained_hf( 50 | WhichModel.Jina, model_id="jinaai/jina-embeddings-v2-small-en" 51 | ) 52 | 53 | jina_embeddings = embed_anything.embed_query(sentences, embedder=jina_model) 54 | jina_query = embed_anything.embed_query(query_text, embedder=jina_model)[0] 55 | 56 | 57 | splade_model = EmbeddingModel.from_pretrained_hf( 58 | WhichModel.SparseBert, "prithivida/Splade_PP_en_v1" 59 | ) 60 | splade_embeddings = embed_anything.embed_query(sentences, embedder=splade_model) 61 | 62 | splade_query = embed_anything.embed_query(query_text, embedder=splade_model) 63 | 64 | client.query_points( 65 | collection_name="my-hybrid-collection", 66 | prefetch=[ 67 | models.Prefetch( 68 | query=jina_query, # <-- dense vector 69 | limit=10, 70 | ), 71 | models.Prefetch( 72 | query=splade_query, # <-- sparse vector 73 | limit=10, 74 | ), 75 | ], 76 | query=models.FusionQuery(fusion=models.Fusion.RRF), 77 | ) 78 | -------------------------------------------------------------------------------- /examples/images/blue_jacket.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/examples/images/blue_jacket.jpg
-------------------------------------------------------------------------------- /examples/images/coat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/examples/images/coat.jpg -------------------------------------------------------------------------------- /examples/images/faux.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/examples/images/faux.webp -------------------------------------------------------------------------------- /examples/images/shirt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/examples/images/shirt.jpg -------------------------------------------------------------------------------- /examples/images/white shirt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/examples/images/white shirt.jpg -------------------------------------------------------------------------------- /examples/late_chunking.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel 3 | 4 | from embed_anything import Dtype, ONNXModel 5 | import numpy as np 6 | 7 | 8 | model:EmbeddingModel = EmbeddingModel.from_pretrained_onnx( 9 | WhichModel.Jina, hf_model_id="jinaai/jina-embeddings-v2-small-en", path_in_repo="model.onnx" 10 | ) 11 | config = TextEmbedConfig( 12 | chunk_size=1000, 13 | batch_size=8, 14 | splitting_strategy="sentence", 15 | late_chunking=True, 16 | ) 17 | 18 | # Embed a single file 19 | data: list[EmbedData] = model.embed_file("test_files/attention.pdf", config=config) 20 | 21 | # Print the embedded data 22 | # for d in data: 23 | # print(d.text) 24 | # print("---" * 20) 25 | 26 | query = "What are positional encodings?" 
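# Background note (an added hint based on the general late-chunking technique, not this codebase):
# with late_chunking=True the model encodes a long stretch of text in one pass and only afterwards
# pools token embeddings into per-chunk vectors, so each chunk keeps context from its neighbours.
# The search below is ordinary dense retrieval over those chunk embeddings.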
27 | 28 | query_embedding = np.array(model.embed_query([query])[0].embedding) 29 | 30 | embedding_array = np.array([e.embedding for e in data]) 31 | 32 | similarities = np.matmul(query_embedding, embedding_array.T) 33 | 34 | # get top 5 similarities and its index 35 | top_5_similarities = np.argsort(similarities)[-10:][::-1] 36 | 37 | # Print the top 5 similarities with sentences 38 | for i in top_5_similarities: 39 | print(f"Score: {similarities[i]:.2} | {data[i].text}") 40 | print("---" * 20) 41 | -------------------------------------------------------------------------------- /examples/model2vec.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import embed_anything 4 | from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel 5 | 6 | # Initialize the model once 7 | model:EmbeddingModel = EmbeddingModel.from_pretrained_hf( 8 | WhichModel.Model2Vec, model_id="minishlab/potion-base-8M" 9 | ) 10 | 11 | 12 | # Example 3: Embedding a File 13 | def embed_file_example(): 14 | # Configure the embedding process 15 | config = TextEmbedConfig( 16 | chunk_size=1000, batch_size=32, buffer_size=64, splitting_strategy="sentence" 17 | ) 18 | 19 | # Embed a single file 20 | data: list[EmbedData] = model.embed_file( 21 | "test_files/bank.txt", config=config 22 | ) 23 | 24 | # Print the embedded data 25 | for d in data: 26 | print(d.text) 27 | print("---" * 20) 28 | 29 | embed_file_example() -------------------------------------------------------------------------------- /examples/onnx_models.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | from embed_anything import ( 3 | EmbeddingModel, 4 | TextEmbedConfig, 5 | WhichModel, 6 | embed_query, 7 | ONNXModel, 8 | Dtype, 9 | ) 10 | import os 11 | from time import time 12 | import numpy as np 13 | 14 | model = EmbeddingModel.from_pretrained_onnx( 15 | WhichModel.Bert, ONNXModel.ModernBERTBase, dtype = Dtype.Q4F16 16 | ) 17 | 18 | # model = EmbeddingModel.from_pretrained_hf( 19 | # WhichModel.Bert, "BAAI/bge-small-en-v1.5" 20 | # ) 21 | 22 | sentences = [ 23 | "The quick brown fox jumps over the lazy dog", 24 | "The cat is sleeping on the mat", 25 | "The dog is barking at the moon", 26 | "I love pizza", 27 | "I like to have pasta", 28 | "The dog is sitting in the park", 29 | ] 30 | 31 | embedddings = embed_query(sentences, embedder=model) 32 | 33 | embed_vector = np.array([e.embedding for e in embedddings]) 34 | 35 | print("shape of embed_vector", embed_vector.shape) 36 | similarities = np.matmul(embed_vector, embed_vector.T) 37 | 38 | # get top 5 similarities and show the two sentences and their similarity scores 39 | # Flatten the upper triangle of the similarity matrix, excluding the diagonal 40 | similarity_scores = [ 41 | (similarities[i, j], i, j) 42 | for i in range(len(sentences)) 43 | for j in range(i + 1, len(sentences)) 44 | ] 45 | 46 | # Get the top 5 similarity scores 47 | top_5_similarities = heapq.nlargest(5, similarity_scores, key=lambda x: x[0]) 48 | 49 | # Print the top 5 similarities with sentences 50 | for score, i, j in top_5_similarities: 51 | print(f"Score: {score:.2} | {sentences[i]} | {sentences[j]}") 52 | 53 | 54 | from embed_anything import EmbeddingModel, WhichModel, embed_query, TextEmbedConfig 55 | import os 56 | import pymupdf 57 | from semantic_text_splitter import TextSplitter 58 | import os 59 | 60 | model = EmbeddingModel.from_pretrained_onnx(WhichModel.Bert, 
ONNXModel.BGESmallENV15Q) 61 | splitter = TextSplitter(1000) 62 | config = TextEmbedConfig(batch_size=128) 63 | 64 | 65 | def embed_anything(): 66 | # get all pdfs from test_files 67 | 68 | for file in os.listdir("bench"): 69 | text = [] 70 | doc = pymupdf.open("bench/" + file) 71 | 72 | for page in doc: 73 | text.append(page.get_text()) 74 | 75 | text = " ".join(text) 76 | chunks = splitter.chunks(text) 77 | embeddings = embed_query(chunks, model, config) 78 | 79 | 80 | start = time() 81 | embed_anything() 82 | 83 | print(time() - start) 84 | -------------------------------------------------------------------------------- /examples/reranker.py: -------------------------------------------------------------------------------- 1 | from embed_anything import Reranker, Dtype, RerankerResult, DocumentRank 2 | 3 | reranker = Reranker.from_pretrained("jinaai/jina-reranker-v1-turbo-en", dtype=Dtype.F16) 4 | 5 | results: list[RerankerResult] = reranker.rerank(["What is the capital of France?"], ["France is a country in Europe.", "Paris is the capital of France."], 2) 6 | 7 | for result in results: 8 | documents: list[DocumentRank] = result.documents 9 | print(documents) 10 | -------------------------------------------------------------------------------- /examples/semantic_chunking.py: -------------------------------------------------------------------------------- 1 | import embed_anything 2 | from embed_anything import EmbeddingModel, TextEmbedConfig, WhichModel 3 | 4 | model = EmbeddingModel.from_pretrained_hf( 5 | WhichModel.Jina, model_id="jinaai/jina-embeddings-v2-small-en" 6 | ) 7 | 8 | # with semantic encoder 9 | semantic_encoder = EmbeddingModel.from_pretrained_hf( 10 | WhichModel.Jina, model_id="jinaai/jina-embeddings-v2-small-en" 11 | ) 12 | config = TextEmbedConfig( 13 | chunk_size=1000, 14 | batch_size=32, 15 | splitting_strategy="semantic", 16 | semantic_encoder=semantic_encoder, 17 | ) 18 | 19 | data = embed_anything.embed_file("test_files/bank.txt", embedder=model, config=config) 20 | 21 | for d in data: 22 | print(d.text) 23 | print("---" * 20) 24 | -------------------------------------------------------------------------------- /examples/splade.py: -------------------------------------------------------------------------------- 1 | import embed_anything 2 | from embed_anything import EmbedData, EmbeddingModel, ONNXModel, WhichModel, embed_query 3 | from embed_anything.vectordb import Adapter 4 | import os 5 | from time import time 6 | import numpy as np 7 | import heapq 8 | 9 | 10 | model = EmbeddingModel.from_pretrained_hf( 11 | WhichModel.SparseBert, "prithivida/Splade_PP_en_v1" 12 | ) 13 | 14 | ## ONNX model 15 | # model = EmbeddingModel.from_pretrained_onnx( 16 | # WhichModel.SparseBert, 17 | # ONNXModel.SPLADEPPENV2, 18 | # ) 19 | sentences = [ 20 | "The cat sits outside", 21 | "A man is playing guitar", 22 | "I love pasta", 23 | "The new movie is awesome", 24 | "The cat plays in the garden", 25 | "A woman watches TV", 26 | "The new movie is so great", 27 | "Do you like pizza?", 28 | ] 29 | 30 | embedddings = embed_query(sentences, embedder=model) 31 | 32 | embed_vector = np.array([e.embedding for e in embedddings]) 33 | 34 | similarities = np.matmul(embed_vector, embed_vector.T) 35 | 36 | # get top 5 similarities and show the two sentences and their similarity scores 37 | # Flatten the upper triangle of the similarity matrix, excluding the diagonal 38 | similarity_scores = [ 39 | (similarities[i, j], i, j) 40 | for i in range(len(sentences)) 41 | for j in range(i + 1, 
len(sentences)) 42 | ] 43 | 44 | # Get the top 5 similarity scores 45 | top_5_similarities = heapq.nlargest(5, similarity_scores, key=lambda x: x[0]) 46 | 47 | # Print the top 5 similarities with sentences 48 | for score, i, j in top_5_similarities: 49 | print(f"Score: {score:.2} | {sentences[i]} | {sentences[j]}") 50 | -------------------------------------------------------------------------------- /examples/text.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import embed_anything 4 | from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel 5 | 6 | # Initialize the model once 7 | model:EmbeddingModel = EmbeddingModel.from_pretrained_hf( 8 | WhichModel.Jina, model_id="jinaai/jina-embeddings-v2-small-en" 9 | ) 10 | 11 | 12 | # Example 1: Embedding a Directory 13 | def embed_directory_example(): 14 | # Configure the embedding process 15 | config = TextEmbedConfig( 16 | chunk_size=1000, batch_size=32, buffer_size=64, splitting_strategy="sentence" 17 | ) 18 | 19 | # Start timing 20 | start = time.time() 21 | 22 | # Embed all files in a directory 23 | data: list[EmbedData] = model.embed_directory( 24 | "bench", config=config 25 | ) 26 | 27 | # End timing 28 | end = time.time() 29 | 30 | print(f"Time taken to embed directory: {end - start} seconds") 31 | 32 | 33 | # Example 2: Embedding a Query 34 | def embed_query_example(): 35 | # Configure the embedding process 36 | config = TextEmbedConfig( 37 | chunk_size=1000, batch_size=32, splitting_strategy="sentence" 38 | ) 39 | 40 | # Embed a query 41 | embeddings: EmbedData = model.embed_query( 42 | ["Hello world my"], config=config 43 | )[0] 44 | 45 | # Print the shape of the embedding 46 | print(np.array(embeddings.embedding).shape) 47 | 48 | # Embed another query and print the result 49 | print( 50 | embed_anything.embed_query( 51 | ["What is the capital of India?"], embedder=model, config=config 52 | ) 53 | ) 54 | 55 | 56 | # Example 3: Embedding a File 57 | def embed_file_example(): 58 | # Configure the embedding process 59 | config = TextEmbedConfig( 60 | chunk_size=1000, batch_size=32, buffer_size=64, splitting_strategy="sentence" 61 | ) 62 | 63 | # Embed a single file 64 | data: list[EmbedData] = model.embed_file( 65 | "test_files/bank.txt", config=config 66 | ) 67 | 68 | # Print the embedded data 69 | for d in data: 70 | print(d.text) 71 | print("---" * 20) 72 | 73 | # Example 4: Embed files in a batch 74 | def embed_files_batch_example(): 75 | 76 | config = TextEmbedConfig(chunk_size = 1000, batch_size = 32, buffer_size = 64) 77 | 78 | data = model.embed_files_batch(["test_files/bank.txt", "test_files/test.pdf"]) 79 | 80 | for d in data: 81 | print(d.text) 82 | print("---" * 20) 83 | 84 | # Call the examples 85 | embed_directory_example() 86 | embed_query_example() 87 | embed_file_example() 88 | embed_files_batch_example() 89 | -------------------------------------------------------------------------------- /examples/text_ocr.py: -------------------------------------------------------------------------------- 1 | # OCR Requires `tesseract` and `poppler` to be installed. 
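# Typical installs (added hint; package names are an assumption and vary by platform):
#   Debian/Ubuntu: sudo apt-get install tesseract-ocr poppler-utils
#   macOS (Homebrew): brew install tesseract poppler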
2 | 3 | import time 4 | import embed_anything 5 | from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel 6 | from time import time 7 | 8 | 9 | model = EmbeddingModel.from_pretrained_hf( 10 | WhichModel.Jina, model_id="jinaai/jina-embeddings-v2-small-en" 11 | ) 12 | 13 | config = TextEmbedConfig( 14 | chunk_size=1000, 15 | batch_size=32, 16 | buffer_size=64, 17 | splitting_strategy="sentence", 18 | use_ocr=True, 19 | ) 20 | 21 | start = time() 22 | 23 | data: list[EmbedData] = embed_anything.embed_file( 24 | "/home/akshay/projects/starlaw/src-server/test_files/court.pdf", # Replace with your file path 25 | embedder=model, 26 | config=config, 27 | ) 28 | end = time() 29 | 30 | for d in data: 31 | print(d.text) 32 | print("---" * 20) 33 | 34 | print(f"Time taken: {end - start} seconds") 35 | -------------------------------------------------------------------------------- /examples/web.py: -------------------------------------------------------------------------------- 1 | import embed_anything 2 | 3 | data = embed_anything.embed_webpage("https://www.akshaymakes.com/", embedder="Bert") 4 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Starlight Search 2 | site_url: https://starlight-search.com 3 | repo_url: https://github.com/StarlightSearch/EmbedAnything 4 | repo_name: StarlightSearch/EmbedAnything 5 | 6 | theme: 7 | name: "material" 8 | logo: "assets/128x128.png" 9 | favicon: "assets/icon.ico" 10 | icon: 11 | repo: fontawesome/brands/github 12 | 13 | features: 14 | - search.suggest 15 | - search.highlight 16 | - navigation.instant 17 | - navigation.tracking 18 | - navigation.expand 19 | - navigation.sections 20 | - content.code.annotate 21 | - toc.follow 22 | - header.autohide 23 | - announce.dismiss 24 | 25 | palette: 26 | # Palette toggle for light mode 27 | - scheme: default 28 | primary: indigo 29 | toggle: 30 | icon: material/brightness-7 31 | name: Switch to dark mode 32 | 33 | 34 | # Palette toggle for dark mode 35 | - scheme: slate 36 | primary: black 37 | toggle: 38 | icon: material/brightness-4 39 | name: Switch to light mode 40 | 41 | plugins: 42 | - mkdocstrings 43 | - search 44 | - blog: 45 | archive: false 46 | 47 | nav: 48 | - index.md 49 | - references.md 50 | - Blog: 51 | - blog/index.md 52 | - Guides: 53 | - guides/colpali.md 54 | - guides/images.md 55 | - guides/semantic.md 56 | - guides/adapters.md 57 | - guides/onnx_models.md 58 | - guides/ocr.md 59 | - Contribution: 60 | - roadmap/roadmap.md 61 | - roadmap/contribution.md 62 | 63 | markdown_extensions: 64 | - pymdownx.highlight: 65 | anchor_linenums: true 66 | line_spans: __span 67 | pygments_lang_class: true 68 | - pymdownx.inlinehilite 69 | - pymdownx.snippets 70 | - pymdownx.superfences 71 | - def_list 72 | - pymdownx.tasklist: 73 | custom_checkbox: true 74 | 75 | extra: 76 | analytics: 77 | provider: google 78 | property: G-25WL8Y1K9Y 79 | feedback: 80 | title: Was this page helpful? 81 | ratings: 82 | - icon: material/emoticon-happy-outline 83 | name: This page was helpful 84 | data: 1 85 | note: >- 86 | Thanks for your feedback! 87 | - icon: material/emoticon-sad-outline 88 | name: This page could be improved 89 | data: 0 90 | note: >- 91 | Thanks for your feedback! Help us improve this page by 92 | using our feedback form. 
93 | social: 94 | - icon: fontawesome/brands/twitter 95 | link: https://x.com/SearchStarlight 96 | - icon: fontawesome/brands/linkedin 97 | link: https://www.linkedin.com/company/mystarlight/ 98 | - icon: fontawesome/brands/discord 99 | link: https://discord.gg/5wX6c4R7zp 100 | - icon: fontawesome/solid/envelope 101 | link: https://starlight-3.kit.com/f15e780cc7 102 | 103 | copyright: Copyright © 2024 Starlight Search 104 | -------------------------------------------------------------------------------- /processors/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "processors" 3 | version.workspace = true 4 | edition.workspace = true 5 | license.workspace = true 6 | description.workspace = true 7 | repository.workspace = true 8 | authors.workspace = true 9 | readme = "README.md" 10 | 11 | [dependencies] 12 | # HTTP Client 13 | reqwest = { version = "0.12.15", default-features = false, features = ["json", "blocking"] } 14 | 15 | # Natural Language Processing 16 | text-splitter = { version= "0.25.1", features=["tokenizers", "markdown"] } 17 | 18 | # Error Handling 19 | anyhow = "1.0.98" 20 | 21 | # HTML processing 22 | htmd = "0.1.6" 23 | 24 | # PDF processing 25 | pdf-extract = {workspace = true} 26 | docx-parser = "0.1.1" 27 | pdf2image = "0.1.3" 28 | image = "0.25.6" 29 | thiserror = "2.0.12" 30 | tempfile = "3.19.1" 31 | 32 | [dev-dependencies] 33 | tempdir = "0.3.7" 34 | 35 | [features] 36 | default = [] 37 | -------------------------------------------------------------------------------- /processors/README.md: -------------------------------------------------------------------------------- 1 | # Processors 2 | 3 | This crate contains various "processors" that accept files/folders/bytes and produce a chunked, metadata-rich document 4 | description. This is especially helpful for retrieval-augmented generation! 5 | -------------------------------------------------------------------------------- /processors/src/docx_processor.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | use docx_parser::MarkdownDocument; 3 | use text_splitter::ChunkConfigError; 4 | use crate::markdown_processor::MarkdownProcessor; 5 | use crate::processor::{Document, DocumentProcessor, FileProcessor}; 6 | 7 | /// A struct for processing DOCX files. 8 | pub struct DocxProcessor { 9 | markdown_processor: MarkdownProcessor, 10 | } 11 | 12 | impl DocxProcessor { 13 | pub fn new(chunk_size: usize, overlap: usize) -> Result<Self, ChunkConfigError> { 14 | let markdown_processor = MarkdownProcessor::new(chunk_size, overlap)?; 15 | Ok(DocxProcessor { 16 | markdown_processor, 17 | }) 18 | } 19 | } 20 | 21 | impl FileProcessor for DocxProcessor { 22 | fn process_file(&self, path: impl AsRef<Path>) -> anyhow::Result<Document> { 23 | let docs = MarkdownDocument::from_file(path); 24 | let markdown = docs.to_markdown(false); 25 | self.markdown_processor.process_document(&markdown) 26 | } 27 | } 28 | 29 | #[cfg(test)] 30 | mod tests { 31 | use super::*; 32 | #[test] 33 | fn test_extract_text() { 34 | let txt_file = "../test_files/test.docx"; 35 | let processor = DocxProcessor::new(128, 0).unwrap(); 36 | 37 | let text = processor.process_file(&txt_file).unwrap(); 38 | assert!(text.chunks.contains(&"This is a docx file test".to_string())); 39 | } 40 | 41 | // Returns an error if the file path is invalid.
42 | #[test] 43 | #[should_panic(expected = "Error processing file: IO(Os { code: 2, kind: NotFound, message: \"No such file or directory\" })")] 44 | fn test_extract_text_invalid_file_path() { 45 | let invalid_file_path = "this_file_definitely_does_not_exist.docx"; 46 | let processor = DocxProcessor::new(128, 0).unwrap(); 47 | processor.process_file(&invalid_file_path).unwrap(); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /processors/src/html_processor.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use htmd::{HtmlToMarkdown, HtmlToMarkdownBuilder}; 3 | use text_splitter::ChunkConfigError; 4 | use crate::markdown_processor::MarkdownProcessor; 5 | use crate::processor::{Document, DocumentProcessor}; 6 | 7 | pub struct HtmlDocument { 8 | pub content: String, 9 | pub origin: Option, 10 | } 11 | 12 | /// A Struct for processing HTML files. 13 | pub struct HtmlProcessor { 14 | markdown_processor: MarkdownProcessor, 15 | html_to_markdown: HtmlToMarkdown, 16 | } 17 | 18 | impl HtmlProcessor { 19 | pub fn new(chunk_size: usize, overlap: usize) -> Result { 20 | let markdown_processor = MarkdownProcessor::new(chunk_size, overlap)?; 21 | let html_to_markdown = HtmlToMarkdownBuilder::new() 22 | .build(); 23 | Ok(HtmlProcessor { 24 | markdown_processor, 25 | html_to_markdown, 26 | }) 27 | } 28 | } 29 | 30 | impl DocumentProcessor for HtmlProcessor { 31 | fn process_document(&self, content: &str) -> Result { 32 | let content = self.html_to_markdown.convert(content)?; 33 | self.markdown_processor.process_document(&content) 34 | } 35 | } 36 | 37 | #[cfg(test)] 38 | mod tests { 39 | use crate::processor::FileProcessor; 40 | use super::*; 41 | 42 | #[test] 43 | fn test_process_html_file() { 44 | let html_processor = HtmlProcessor::new(128, 0).unwrap(); 45 | let html_file = "../test_files/test.html"; 46 | let result = html_processor.process_file(html_file); 47 | assert!(result.is_ok()); 48 | } 49 | 50 | #[test] 51 | fn test_process_html_file_err() { 52 | let html_processor = HtmlProcessor::new(128, 0).unwrap(); 53 | let html_file = "../test_files/some_file_that_doesnt_exist.html"; 54 | let result = html_processor.process_file(html_file); 55 | assert!(result.is_err()); 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /processors/src/lib.rs: -------------------------------------------------------------------------------- 1 | /// This library contains the traits and structs used to build a processor for arbitrary contents. 2 | pub mod processor; 3 | 4 | /// This module contains the file processor for different file types. 5 | pub mod pdf; 6 | 7 | /// This module contains the file processor for markdown files. 8 | pub mod markdown_processor; 9 | 10 | /// This module contains the file processor for text files. 11 | pub mod txt_processor; 12 | 13 | /// This module contains the file processor for HTML files. 14 | pub mod html_processor; 15 | 16 | /// This module contains the file processor for DOCX files. 
17 | pub mod docx_processor; 18 | -------------------------------------------------------------------------------- /processors/src/markdown_processor.rs: -------------------------------------------------------------------------------- 1 | use text_splitter::{Characters, ChunkConfig, ChunkConfigError, MarkdownSplitter}; 2 | use crate::processor::{Document, DocumentProcessor}; 3 | 4 | /// A struct that provides functionality to process Markdown files. 5 | pub struct MarkdownProcessor { 6 | splitter: MarkdownSplitter<Characters> 7 | } 8 | 9 | impl MarkdownProcessor { 10 | pub fn new(chunk_size: usize, overlap: usize) -> Result<Self, ChunkConfigError> { 11 | let splitter_config = ChunkConfig::new(chunk_size) 12 | .with_overlap(overlap)?; 13 | let splitter = MarkdownSplitter::new(splitter_config); 14 | Ok(MarkdownProcessor { 15 | splitter 16 | }) 17 | } 18 | } 19 | 20 | impl DocumentProcessor for MarkdownProcessor { 21 | 22 | fn process_document(&self, content: &str) -> anyhow::Result<Document> { 23 | let chunks = self.splitter.chunks(content) 24 | .map(|x| x.to_string()) 25 | .collect(); 26 | Ok(Document { 27 | chunks 28 | }) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /processors/src/pdf/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod tesseract; 2 | 3 | pub mod pdf_processor; 4 | -------------------------------------------------------------------------------- /processors/src/pdf/pdf_processor.rs: -------------------------------------------------------------------------------- 1 | use crate::markdown_processor::MarkdownProcessor; 2 | use crate::pdf::tesseract::input::{Args, Image}; 3 | use crate::processor::{Document, DocumentProcessor, FileProcessor}; 4 | use anyhow::Error; 5 | use image::DynamicImage; 6 | use pdf2image::{Pages, RenderOptionsBuilder, PDF}; 7 | use std::path::Path; 8 | use text_splitter::ChunkConfigError; 9 | 10 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 11 | pub enum PdfBackend { 12 | LoPdf, 13 | } 14 | 15 | /// A struct for processing PDF files. 16 | pub struct PdfProcessor { 17 | markdown_processor: MarkdownProcessor, 18 | ocr_config: OcrConfig, 19 | backend: PdfBackend, 20 | } 21 | 22 | pub struct OcrConfig { 23 | pub use_ocr: bool, 24 | pub tesseract_path: Option<String>, 25 | } 26 | 27 | impl PdfProcessor { 28 | pub fn new( 29 | chunk_size: usize, 30 | overlap: usize, 31 | ocr_config: OcrConfig, 32 | backend: PdfBackend, 33 | ) -> Result<Self, ChunkConfigError> { 34 | let markdown_processor = MarkdownProcessor::new(chunk_size, overlap)?; 35 | Ok(PdfProcessor { 36 | markdown_processor, 37 | ocr_config, 38 | backend, 39 | }) 40 | } 41 | } 42 | 43 | impl FileProcessor for PdfProcessor { 44 | fn process_file(&self, path: impl AsRef<Path>) -> anyhow::Result<Document> { 45 | let content = if self.ocr_config.use_ocr { 46 | let tesseract_path = self.ocr_config.tesseract_path.as_deref(); 47 | extract_text_with_ocr(&path, tesseract_path)? 48 | } else { 49 | match self.backend { 50 | PdfBackend::LoPdf => { 51 | pdf_extract::extract_text(path.as_ref()).map_err(|e| anyhow::anyhow!(e))?
52 | } 53 | } 54 | }; 55 | 56 | self.markdown_processor.process_document(&content) 57 | } 58 | } 59 | 60 | fn get_images_from_pdf<T: AsRef<Path>>(file_path: &T) -> Result<Vec<DynamicImage>, Error> { 61 | let pdf = PDF::from_file(file_path)?; 62 | let page_count = pdf.page_count(); 63 | let pages = pdf.render( 64 | Pages::Range(1..=page_count), 65 | RenderOptionsBuilder::default().build()?, 66 | )?; 67 | Ok(pages) 68 | } 69 | 70 | fn extract_text_from_image(image: &DynamicImage, args: &Args) -> Result<String, Error> { 71 | let image = Image::from_dynamic_image(image)?; 72 | let text = crate::pdf::tesseract::command::image_to_string(&image, args)?; 73 | Ok(text) 74 | } 75 | 76 | fn extract_text_with_ocr<T: AsRef<Path>>( 77 | file_path: &T, 78 | tesseract_path: Option<&str>, 79 | ) -> Result<String, Error> { 80 | let images = get_images_from_pdf(file_path)?; 81 | let texts: Result<Vec<String>, Error> = images 82 | .iter() 83 | .map(|image| extract_text_from_image(image, &Args::default().with_path(tesseract_path))) 84 | .collect(); 85 | 86 | // Join the texts and clean up empty lines 87 | let text = texts?.join("\n"); 88 | let cleaned_text = text 89 | .lines() 90 | .filter(|line| !line.trim().is_empty()) 91 | .collect::<Vec<_>>() 92 | .join("\n"); 93 | 94 | Ok(cleaned_text) 95 | } 96 | 97 | #[cfg(test)] 98 | mod tests { 99 | use super::*; 100 | use std::fs::File; 101 | use tempdir::TempDir; 102 | 103 | #[test] 104 | fn test_extract_text() { 105 | let temp_dir = TempDir::new("example").unwrap(); 106 | let pdf_file = temp_dir.path().join("test.pdf"); 107 | let processor = PdfProcessor::new( 108 | 128, 109 | 0, 110 | OcrConfig { 111 | use_ocr: false, 112 | tesseract_path: None, 113 | }, 114 | PdfBackend::LoPdf, 115 | ) 116 | .unwrap(); 117 | 118 | File::create(pdf_file).unwrap(); 119 | 120 | let pdf_file = "../test_files/test.pdf"; 121 | let text = processor.process_file(pdf_file).unwrap(); 122 | assert_eq!(text.chunks.len(), 4271); 123 | } 124 | 125 | #[test] 126 | fn test_extract_text_with_ocr() { 127 | let pdf_file = "../test_files/test.pdf"; 128 | let path = Path::new(pdf_file); 129 | 130 | // Check if the path exists 131 | if !path.exists() { 132 | panic!("File does not exist: {}", path.display()); 133 | } 134 | 135 | // Print the absolute path 136 | println!("Absolute path: {}", path.canonicalize().unwrap().display()); 137 | 138 | let text = extract_text_with_ocr(&pdf_file, None).unwrap(); 139 | 140 | println!("Text: {}", text); 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /processors/src/pdf/tesseract/command.rs: -------------------------------------------------------------------------------- 1 | use input::{Args, Image}; 2 | 3 | use super::*; 4 | use std::process::{Command, Stdio}; 5 | use std::string::ToString; 6 | 7 | #[cfg(target_os = "windows")] 8 | use std::os::windows::process::CommandExt; 9 | use crate::pdf::tesseract::error::{TessError, TessResult}; 10 | 11 | #[cfg(target_os = "windows")] 12 | const CREATE_NO_WINDOW: u32 = 0x08000000; 13 | 14 | pub(crate) fn get_tesseract_command(path: Option<&str>) -> Command { 15 | if let Some(path) = path { 16 | Command::new(path) 17 | } else { 18 | let tesseract = if cfg!(target_os = "windows") { 19 | "tesseract.exe" 20 | } else { 21 | "tesseract" 22 | }; 23 | 24 | Command::new(tesseract) 25 | } 26 | } 27 | 28 | pub fn get_tesseract_version() -> TessResult<String> { 29 | let mut command = get_tesseract_command(None); 30 | command.arg("--version"); 31 | 32 | run_tesseract_command(&mut command) 33 | } 34 | 35 | pub fn get_tesseract_langs() -> TessResult<Vec<String>> { 36 | let mut command =
get_tesseract_command(None); 37 | command.arg("--list-langs"); 38 | 39 | let output = run_tesseract_command(&mut command)?; 40 | let langs = output.lines().skip(1).map(|x| x.into()).collect(); 41 | Ok(langs) 42 | } 43 | 44 | pub(crate) fn run_tesseract_command(command: &mut Command) -> TessResult { 45 | if cfg!(debug_assertions) { 46 | show_command(command); 47 | } 48 | 49 | #[cfg(target_os = "windows")] 50 | command.creation_flags(CREATE_NO_WINDOW); 51 | 52 | let child = command 53 | .stdout(Stdio::piped()) 54 | .stderr(Stdio::piped()) 55 | .spawn() 56 | .map_err(|_| TessError::TesseractNotFoundError)?; 57 | 58 | let output = child 59 | .wait_with_output() 60 | .map_err(|_| TessError::TesseractNotFoundError)?; 61 | 62 | let out = String::from_utf8(output.stdout).unwrap(); 63 | let err = String::from_utf8(output.stderr).unwrap(); 64 | let status = output.status; 65 | 66 | match status.code() { 67 | Some(0) => Ok(out), 68 | _ => Err(TessError::CommandExitStatusError(status.to_string(), err)), 69 | } 70 | } 71 | 72 | fn show_command(command: &Command) { 73 | let params: Vec = command 74 | .get_args() 75 | .map(|x| x.to_str().unwrap_or("")) 76 | .map(|s| s.to_string()) 77 | .collect(); 78 | 79 | println!( 80 | "Tesseract Command: {} {}", 81 | command.get_program().to_str().unwrap(), 82 | params.join(" ") 83 | ); 84 | } 85 | 86 | pub fn image_to_string(image: &Image, args: &Args) -> TessResult { 87 | let mut command = create_tesseract_command(image, args)?; 88 | let output = run_tesseract_command(&mut command)?; 89 | 90 | Ok(output) 91 | } 92 | 93 | pub(crate) fn create_tesseract_command(image: &Image, args: &Args) -> TessResult { 94 | let path = args.path.clone(); 95 | let mut command = get_tesseract_command(path.as_deref()); 96 | command 97 | .arg(image.get_image_path()?) 98 | .arg("stdout") 99 | .arg("-l") 100 | .arg(args.lang.clone()); 101 | 102 | if let Some(dpi) = args.dpi { 103 | command.arg("--dpi").arg(dpi.to_string()); 104 | } 105 | 106 | if let Some(psm) = args.psm { 107 | command.arg("--psm").arg(psm.to_string()); 108 | } 109 | 110 | if let Some(oem) = args.oem { 111 | command.arg("--oem").arg(oem.to_string()); 112 | } 113 | 114 | for parameter in args.get_config_variable_args() { 115 | command.arg("-c").arg(parameter); 116 | } 117 | 118 | Ok(command) 119 | } 120 | 121 | #[cfg(test)] 122 | mod tests { 123 | use crate::pdf::tesseract::command::get_tesseract_langs; 124 | 125 | #[test] 126 | fn test_get_tesseract_langs() { 127 | let langs = get_tesseract_langs().unwrap(); 128 | 129 | assert!(langs.contains(&"eng".into())); 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /processors/src/pdf/tesseract/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug, PartialEq)] 4 | pub enum TessError { 5 | #[error("Tesseract not found. 
Please check installation path!")] 6 | TesseractNotFoundError, 7 | 8 | #[error("Command ExitStatusError\n{0}")] 9 | CommandExitStatusError(String, String), 10 | 11 | #[error( 12 | "Image format not within the list of allowed image formats:\n\ 13 | ['JPEG','JPG','PNG','PBM','PGM','PPM','TIFF','BMP','GIF','WEBP']" 14 | )] 15 | ImageFormatError, 16 | 17 | #[error("Please assign a valid image path.")] 18 | ImageNotFoundError, 19 | 20 | #[error("Could not parse {0}.")] 21 | ParseError(String), 22 | 23 | #[error("Could not create tempfile.\n{0}")] 24 | TempfileError(String), 25 | 26 | #[error("Could not save dynamic image to tempfile.\n{0}")] 27 | DynamicImageError(String), 28 | } 29 | 30 | pub type TessResult<T> = Result<T, TessError>; 31 | -------------------------------------------------------------------------------- /processors/src/pdf/tesseract/input.rs: -------------------------------------------------------------------------------- 1 | use image::DynamicImage; 2 | use std::{ 3 | collections::HashMap, 4 | fmt::{self}, 5 | path::{Path, PathBuf}, 6 | }; 7 | use crate::pdf::tesseract::error::{TessError, TessResult}; 8 | 9 | #[derive(Clone, Debug, PartialEq)] 10 | pub struct Args { 11 | pub lang: String, 12 | pub config_variables: HashMap<String, String>, 13 | pub dpi: Option<i32>, 14 | pub psm: Option<i32>, 15 | pub oem: Option<i32>, 16 | pub path: Option<String>, 17 | } 18 | 19 | impl Default for Args { 20 | fn default() -> Self { 21 | Args { 22 | lang: "eng".into(), 23 | config_variables: HashMap::new(), 24 | dpi: Some(150), 25 | psm: Some(3), 26 | oem: Some(3), 27 | path: None, 28 | } 29 | } 30 | } 31 | 32 | impl Args { 33 | pub fn with_path(mut self, path: Option<&str>) -> Self { 34 | self.path = path.map(|p| p.to_string()); 35 | self 36 | } 37 | 38 | pub(crate) fn get_config_variable_args(&self) -> Vec<String> { 39 | self.config_variables 40 | .iter() 41 | .map(|(key, value)| format!("{}={}", key, value)) 42 | .collect::<Vec<String>>() 43 | } 44 | } 45 | 46 | #[derive(Debug)] 47 | pub struct Image { 48 | data: InputData, 49 | } 50 | 51 | impl Image { 52 | pub fn from_path<P: Into<PathBuf>>(path: P) -> TessResult<Self> { 53 | let path = path.into(); 54 | Self::check_image_format(&path)?; 55 | Ok(Self { 56 | data: InputData::Path(path), 57 | }) 58 | } 59 | 60 | fn check_image_format(path: &Path) -> TessResult<()> { 61 | let binding = path 62 | .extension() 63 | .ok_or(TessError::ImageFormatError)? 64 | .to_str() 65 | .ok_or(TessError::ImageFormatError)?
66 | .to_uppercase(); 67 | if matches!( 68 | binding.as_str(), 69 | "JPEG" | "JPG" | "PNG" | "PBM" | "PGM" | "PPM" | "TIFF" | "BMP" | "GIF" | "WEBP" 70 | ) { 71 | Ok(()) 72 | } else { 73 | Err(TessError::ImageFormatError) 74 | } 75 | } 76 | 77 | pub fn from_dynamic_image(image: &DynamicImage) -> TessResult { 78 | //Store Image as Tempfile 79 | let tempfile = tempfile::Builder::new() 80 | .prefix("rusty-tesseract") 81 | .suffix(".png") 82 | .tempfile() 83 | .map_err(|e| TessError::TempfileError(e.to_string()))?; 84 | let path = tempfile.path(); 85 | image 86 | .save_with_format(path, image::ImageFormat::Png) 87 | .map_err(|e| TessError::DynamicImageError(e.to_string()))?; 88 | 89 | Ok(Self { 90 | data: InputData::Image(tempfile), 91 | }) 92 | } 93 | 94 | pub fn get_image_path(&self) -> TessResult<&str> { 95 | match &self.data { 96 | InputData::Path(x) => x.to_str(), 97 | InputData::Image(x) => x.path().to_str(), 98 | } 99 | .ok_or(TessError::ImageNotFoundError) 100 | } 101 | } 102 | 103 | #[derive(Debug)] 104 | enum InputData { 105 | Path(PathBuf), 106 | Image(tempfile::NamedTempFile), 107 | } 108 | 109 | impl fmt::Display for Image { 110 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 111 | write!(f, "{}", self.get_image_path().unwrap()) 112 | } 113 | } 114 | 115 | #[cfg(test)] 116 | mod tests { 117 | use super::Image; 118 | use image::ImageReader; 119 | 120 | #[test] 121 | fn test_from_path() { 122 | let input = Image::from_path("../test_files/clip/cat1.jpg").unwrap(); 123 | 124 | assert_eq!(input.get_image_path().unwrap(), "../test_files/clip/cat1.jpg") 125 | } 126 | 127 | #[test] 128 | fn test_from_dynamic_image() { 129 | let img = ImageReader::open("../test_files/clip/cat1.jpg") 130 | .unwrap() 131 | .decode() 132 | .unwrap(); 133 | 134 | let input = Image::from_dynamic_image(&img).unwrap(); 135 | 136 | let temppath = input.get_image_path().unwrap(); 137 | 138 | let tempimg = ImageReader::open(temppath).unwrap().decode().unwrap(); 139 | 140 | assert_eq!(img, tempimg); 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /processors/src/pdf/tesseract/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod command; 2 | pub mod error; 3 | pub mod input; 4 | pub mod output_boxes; 5 | pub mod output_config_parameters; 6 | pub mod output_data; 7 | pub mod parse_line_util; 8 | -------------------------------------------------------------------------------- /processors/src/pdf/tesseract/output_boxes.rs: -------------------------------------------------------------------------------- 1 | use core::fmt; 2 | use crate::pdf::tesseract::error::TessResult; 3 | use crate::pdf::tesseract::input::{Args, Image}; 4 | use crate::pdf::tesseract::parse_line_util::{parse_next, FromLine}; 5 | 6 | #[derive(Debug, PartialEq)] 7 | pub struct BoxOutput { 8 | pub output: String, 9 | pub boxes: Vec, 10 | } 11 | 12 | impl fmt::Display for BoxOutput { 13 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 14 | write!(f, "{}", self.output) 15 | } 16 | } 17 | 18 | #[derive(Debug, PartialEq)] 19 | pub struct Box { 20 | pub symbol: String, 21 | pub left: i32, 22 | pub bottom: i32, 23 | pub right: i32, 24 | pub top: i32, 25 | pub page: i32, 26 | } 27 | 28 | impl fmt::Display for Box { 29 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 30 | write!( 31 | f, 32 | "{} {} {} {} {} {}", 33 | self.symbol, self.left, self.bottom, self.right, self.top, self.page 34 | ) 35 | } 36 | } 37 | 38 | impl 
FromLine for Box { 39 | fn from_line(line: &str) -> Option { 40 | let mut x = line.split_whitespace(); 41 | 42 | Some(Box { 43 | symbol: x.next()?.to_string(), 44 | left: parse_next(&mut x)?, 45 | bottom: parse_next(&mut x)?, 46 | right: parse_next(&mut x)?, 47 | top: parse_next(&mut x)?, 48 | page: parse_next(&mut x)?, 49 | }) 50 | } 51 | } 52 | 53 | pub fn image_to_boxes(image: &Image, args: &Args) -> TessResult { 54 | let mut command = crate::pdf::tesseract::command::create_tesseract_command(image, args)?; 55 | command.arg("makebox"); 56 | 57 | let output = crate::pdf::tesseract::command::run_tesseract_command(&mut command)?; 58 | let boxes = string_to_boxes(&output)?; 59 | Ok(BoxOutput { output, boxes }) 60 | } 61 | 62 | fn string_to_boxes(output: &str) -> TessResult> { 63 | output.lines().map(Box::parse).collect::<_>() 64 | } 65 | 66 | #[cfg(test)] 67 | mod tests { 68 | use crate::pdf::tesseract::{ 69 | error::TessError, 70 | output_boxes::{string_to_boxes, Box}, 71 | }; 72 | 73 | #[test] 74 | fn test_string_to_boxes() { 75 | let result = string_to_boxes("L 18 26 36 59 0"); 76 | assert_eq!( 77 | *result.unwrap().first().unwrap(), 78 | Box { 79 | symbol: String::from("L"), 80 | left: 18, 81 | bottom: 26, 82 | right: 36, 83 | top: 59, 84 | page: 0 85 | } 86 | ) 87 | } 88 | 89 | #[test] 90 | fn test_string_to_boxes_parse_error() { 91 | let result = string_to_boxes("L 18 X 36 59 0"); 92 | assert_eq!( 93 | result, 94 | Err(TessError::ParseError( 95 | "invalid line 'L 18 X 36 59 0'".into() 96 | )) 97 | ) 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /processors/src/pdf/tesseract/output_config_parameters.rs: -------------------------------------------------------------------------------- 1 | use parse_line_util::FromLine; 2 | 3 | use super::*; 4 | use core::fmt; 5 | 6 | #[derive(Debug, PartialEq)] 7 | pub struct ConfigParameterOutput { 8 | pub output: String, 9 | pub config_parameters: Vec, 10 | } 11 | 12 | impl fmt::Display for ConfigParameterOutput { 13 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 14 | write!(f, "{}", self.output) 15 | } 16 | } 17 | 18 | #[derive(Debug, PartialEq)] 19 | pub struct ConfigParameter { 20 | pub name: String, 21 | pub default_value: String, 22 | pub description: String, 23 | } 24 | 25 | impl fmt::Display for ConfigParameter { 26 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 27 | write!( 28 | f, 29 | "{} {} {}", 30 | self.name, self.default_value, self.description, 31 | ) 32 | } 33 | } 34 | 35 | impl FromLine for ConfigParameter { 36 | fn from_line(line: &str) -> Option { 37 | let (name, x) = line.split_once("\t")?; 38 | let (default_value, description) = x.split_once("\t")?; 39 | 40 | Some(ConfigParameter { 41 | name: name.into(), 42 | default_value: default_value.into(), 43 | description: description.into(), 44 | }) 45 | } 46 | } 47 | 48 | pub fn get_tesseract_config_parameters( 49 | ) -> error::TessResult { 50 | let mut command = command::get_tesseract_command(None); 51 | command.arg("--print-parameters"); 52 | 53 | let output = command::run_tesseract_command(&mut command)?; 54 | 55 | let config_parameters = string_to_config_parameter_output(&output)?; 56 | 57 | Ok(ConfigParameterOutput { 58 | output, 59 | config_parameters, 60 | }) 61 | } 62 | 63 | fn string_to_config_parameter_output( 64 | output: &str, 65 | ) -> crate::pdf::tesseract::error::TessResult> { 66 | output 67 | .lines() 68 | .skip(1) 69 | .map(ConfigParameter::parse) 70 | .collect::<_>() 71 | } 72 | 73 | 
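// (Editor's sketch, not part of the upstream file.) A minimal, hypothetical example of how
// the pieces of this `tesseract` module fit together, assuming the `tesseract` binary is
// installed and on PATH and that these modules are reachable from the `processors` crate root:
//
//     use processors::pdf::tesseract::input::{Args, Image};
//     use processors::pdf::tesseract::{output_boxes, output_data};
//
//     fn ocr_page() -> processors::pdf::tesseract::error::TessResult<()> {
//         let args = Args::default();                 // lang "eng", dpi 150, psm 3, oem 3
//         let image = Image::from_path("page.png")?;  // file extension is validated up front
//         let boxes = output_boxes::image_to_boxes(&image, &args)?; // runs `tesseract ... makebox`
//         let words = output_data::image_to_data(&image, &args)?;   // runs `tesseract ... tsv`
//         println!("{} boxes, {} words", boxes.boxes.len(), words.data.len());
//         Ok(())
//     }
//
// The function name `ocr_page` and the file name "page.png" are placeholders for illustration only.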
#[cfg(test)] 74 | mod tests { 75 | use crate::pdf::tesseract::output_config_parameters::{string_to_config_parameter_output, ConfigParameter}; 76 | 77 | #[test] 78 | fn test_string_to_config_parameter_output() { 79 | let result = string_to_config_parameter_output( 80 | "Tesseract parameters:\n\ 81 | log_level\t2147483647\tLogging level\n\ 82 | textord_dotmatrix_gap\t3\t pixel gap for broken pixed pitch\n\ 83 | textord_debug_block\t0\tBlock to do debug on\n\ 84 | textord_pitch_range\t2\tMax range test on pitch", 85 | ) 86 | .unwrap(); 87 | 88 | let expected = ConfigParameter { 89 | name: "log_level".into(), 90 | default_value: "2147483647".into(), 91 | description: "Logging level".into(), 92 | }; 93 | 94 | assert_eq!(result.first().unwrap(), &expected); 95 | } 96 | 97 | #[test] 98 | fn test_get_tesseract_config_parameters() { 99 | let result = 100 | crate::pdf::tesseract::output_config_parameters::get_tesseract_config_parameters().unwrap(); 101 | let x = result 102 | .config_parameters 103 | .iter() 104 | .find(|&x| x.name == "tessedit_char_whitelist") 105 | .unwrap(); 106 | 107 | let expected = ConfigParameter { 108 | name: "tessedit_char_whitelist".into(), 109 | default_value: "".into(), 110 | description: "Whitelist of chars to recognize".into(), 111 | }; 112 | 113 | assert_eq!(*x, expected); 114 | } 115 | 116 | #[test] 117 | fn test_string_to_config_parameter_output_parse_error() { 118 | let result = string_to_config_parameter_output( 119 | "Tesseract parameters:\n\ 120 | log_level\t2147483647\tLogging level\n\ 121 | Test\n\ 122 | textord_debug_block\t0\tBlock to do debug on\n\ 123 | textord_pitch_range\t2\tMax range test on pitch", 124 | ); 125 | assert_eq!( 126 | result, 127 | Err(crate::pdf::tesseract::error::TessError::ParseError( 128 | "invalid line 'Test'".into() 129 | )) 130 | ) 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /processors/src/pdf/tesseract/output_data.rs: -------------------------------------------------------------------------------- 1 | use input::{Args, Image}; 2 | use parse_line_util::{parse_next, FromLine}; 3 | 4 | use super::*; 5 | use core::fmt; 6 | 7 | #[derive(Debug, PartialEq)] 8 | pub struct DataOutput { 9 | pub output: String, 10 | pub data: Vec, 11 | } 12 | 13 | impl fmt::Display for DataOutput { 14 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 15 | write!(f, "{}", self.output) 16 | } 17 | } 18 | 19 | #[derive(Debug, PartialEq)] 20 | pub struct Data { 21 | pub level: i32, 22 | pub page_num: i32, 23 | pub block_num: i32, 24 | pub par_num: i32, 25 | pub line_num: i32, 26 | pub word_num: i32, 27 | pub left: i32, 28 | pub top: i32, 29 | pub width: i32, 30 | pub height: i32, 31 | pub conf: f32, 32 | pub text: String, 33 | } 34 | 35 | impl fmt::Display for Data { 36 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 37 | write!( 38 | f, 39 | "{} {} {} {} {} {} {} {} {} {} {} {}", 40 | self.level, 41 | self.page_num, 42 | self.block_num, 43 | self.par_num, 44 | self.line_num, 45 | self.word_num, 46 | self.left, 47 | self.top, 48 | self.width, 49 | self.height, 50 | self.conf, 51 | self.text, 52 | ) 53 | } 54 | } 55 | 56 | impl FromLine for Data { 57 | fn from_line(line: &str) -> Option { 58 | let mut x = line.split_whitespace(); 59 | Some(Data { 60 | level: parse_next(&mut x)?, 61 | page_num: parse_next(&mut x)?, 62 | block_num: parse_next(&mut x)?, 63 | par_num: parse_next(&mut x)?, 64 | line_num: parse_next(&mut x)?, 65 | word_num: parse_next(&mut x)?, 66 | left: parse_next(&mut 
x)?, 67 | top: parse_next(&mut x)?, 68 | width: parse_next(&mut x)?, 69 | height: parse_next(&mut x)?, 70 | conf: parse_next(&mut x)?, 71 | text: x.next().unwrap_or("").to_string(), 72 | }) 73 | } 74 | } 75 | 76 | pub fn image_to_data( 77 | image: &Image, 78 | args: &Args, 79 | ) -> error::TessResult<DataOutput> { 80 | let mut command = command::create_tesseract_command(image, args)?; 81 | command.arg("tsv"); 82 | 83 | let output = command::run_tesseract_command(&mut command)?; 84 | 85 | let data = string_to_data(&output)?; 86 | 87 | Ok(DataOutput { output, data }) 88 | } 89 | 90 | fn string_to_data(output: &str) -> error::TessResult<Vec<Data>> { 91 | output.lines().skip(1).map(Data::parse).collect::<_>() 92 | } 93 | 94 | #[cfg(test)] 95 | mod tests { 96 | use crate::pdf::tesseract::output_data::{string_to_data, Data}; 97 | 98 | #[test] 99 | fn test_string_to_data() { 100 | let result = string_to_data("level page_num block_num par_num line_num word_num left top width height conf text 101 | 5 1 1 1 1 1 65 41 46 20 96.063751 The"); 102 | assert_eq!( 103 | *result.unwrap().first().unwrap(), 104 | Data { 105 | level: 5, 106 | page_num: 1, 107 | block_num: 1, 108 | par_num: 1, 109 | line_num: 1, 110 | word_num: 1, 111 | left: 65, 112 | top: 41, 113 | width: 46, 114 | height: 20, 115 | conf: 96.063751, 116 | text: String::from("The"), 117 | } 118 | ) 119 | } 120 | 121 | 122 | #[test] 123 | fn test_string_to_data_parse_error() { 124 | let result = string_to_data("level page_num block_num par_num line_num word_num left top width height conf text\n\ 125 | Test"); 126 | assert_eq!( 127 | result, 128 | Err(crate::pdf::tesseract::error::TessError::ParseError( 129 | "invalid line 'Test'".into() 130 | )) 131 | ) 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /processors/src/pdf/tesseract/parse_line_util.rs: -------------------------------------------------------------------------------- 1 | use crate::pdf::tesseract::error::{TessError, TessResult}; 2 | 3 | pub(crate) fn parse_next<T: std::str::FromStr>( 4 | iter: &mut std::str::SplitWhitespace<'_>, 5 | ) -> Option<T> { 6 | iter.next()?.parse::<T>().ok() 7 | } 8 | 9 | pub(crate) trait FromLine: Sized { 10 | fn from_line(line: &str) -> Option<Self>; 11 | 12 | fn parse(line: &str) -> TessResult<Self> { 13 | Self::from_line(line).ok_or(TessError::ParseError(format!("invalid line '{}'", line))) 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /processors/src/processor.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | 3 | pub trait DocumentProcessor { 4 | 5 | fn process_document(&self, content: &str) -> anyhow::Result<Document>; 6 | } 7 | 8 | pub trait FileProcessor { 9 | 10 | fn process_file(&self, path: impl AsRef<Path>) -> anyhow::Result<Document>; 11 | } 12 | 13 | pub trait UrlProcessor { 14 | fn process_url(&self, url: &str) -> anyhow::Result<Document>; 15 | } 16 | 17 | impl<T: DocumentProcessor> FileProcessor for T { 18 | fn process_file(&self, path: impl AsRef<Path>) -> anyhow::Result<Document> { 19 | let bytes = std::fs::read(path)?; 20 | let out = String::from_utf8_lossy(&bytes); 21 | self.process_document(&out) 22 | } 23 | } 24 | 25 | impl<T: DocumentProcessor> UrlProcessor for T { 26 | fn process_url(&self, url: &str) -> anyhow::Result<Document> { 27 | let content = reqwest::blocking::get(url)?.text()?; 28 | self.process_document(&content) 29 | } 30 | } 31 | 32 | pub struct Document { 33 | pub chunks: Vec<String> 34 | } 35 | -------------------------------------------------------------------------------- /processors/src/txt_processor.rs:
-------------------------------------------------------------------------------- 1 | use text_splitter::ChunkConfigError; 2 | use crate::markdown_processor::MarkdownProcessor; 3 | use crate::processor::{Document, DocumentProcessor}; 4 | 5 | /// A struct for processing plain-text files. 6 | pub struct TxtProcessor { 7 | markdown_processor: MarkdownProcessor, 8 | } 9 | 10 | impl TxtProcessor { 11 | pub fn new(chunk_size: usize, overlap: usize) -> Result<Self, ChunkConfigError> { 12 | let markdown_processor = MarkdownProcessor::new(chunk_size, overlap)?; 13 | Ok(TxtProcessor { 14 | markdown_processor, 15 | }) 16 | } 17 | } 18 | 19 | impl DocumentProcessor for TxtProcessor { 20 | fn process_document(&self, content: &str) -> anyhow::Result<Document> { 21 | self.markdown_processor.process_document(content) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.5,<2.0"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "embed_anything" 7 | requires-python = ">=3.8" 8 | description = "Embed anything at lightning speed" 9 | readme = "README.md" 10 | classifiers = [ 11 | "Programming Language :: Python :: 3.8", 12 | "Programming Language :: Python :: 3.9", 13 | "Programming Language :: Python :: 3.10", 14 | "Programming Language :: Python :: 3.11", 15 | "Programming Language :: Python :: 3.12", 16 | "License :: OSI Approved :: MIT License" 17 | 18 | ] 19 | dynamic = ["version"] 20 | license = {file = "LICENSE"} 21 | dependencies = ["onnxruntime==1.20.1"] 22 | 23 | [tool.maturin] 24 | features = ["extension-module"] 25 | profile="release" 26 | python-source = "python/python" 27 | manifest-path = "python/Cargo.toml" 28 | module-name = "embed_anything._embed_anything" 29 | 30 | [project.urls] 31 | Homepage = "https://github.com/StarlightSearch/EmbedAnything/tree/main" -------------------------------------------------------------------------------- /python/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "embed_anything_python" 3 | version.workspace = true 4 | edition = "2021" 5 | 6 | [lib] 7 | name = "_embed_anything" 8 | crate-type = ["cdylib"] 9 | 10 | [dependencies] 11 | embed_anything = {path = "../rust", features = ["ort"]} 12 | processors = {path = "../processors"} 13 | pyo3 = { version = "0.23.2"} 14 | tokio = { version = "1.39.0", features = ["rt-multi-thread"]} 15 | strum = {workspace = true} 16 | strum_macros = {workspace = true} 17 | 18 | [features] 19 | extension-module = ["pyo3/extension-module"] 20 | mkl = ["embed_anything/mkl"] 21 | accelerate = ["embed_anything/accelerate"] 22 | cuda = ["embed_anything/cuda"] 23 | cudnn = ["embed_anything/cudnn"] 24 | metal = ["embed_anything/metal"] 25 | ort = ["embed_anything/ort"] 26 | audio = ["embed_anything/audio"] 27 | -------------------------------------------------------------------------------- /python/python/embed_anything/libiomp5.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/python/python/embed_anything/libiomp5.so -------------------------------------------------------------------------------- /python/python/embed_anything/libiomp5md.dll: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/python/python/embed_anything/libiomp5md.dll -------------------------------------------------------------------------------- /python/python/embed_anything/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/python/python/embed_anything/py.typed -------------------------------------------------------------------------------- /python/python/embed_anything/vectordb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import uuid 4 | from typing import List, Dict 5 | from abc import ABC, abstractmethod 6 | from ._embed_anything import EmbedData 7 | 8 | 9 | class Adapter(ABC): 10 | def __init__(self, api_key: str): 11 | self.api_key = api_key 12 | 13 | @abstractmethod 14 | def create_index(self, dimension: int, metric: str, index_name: str, **kwargs): 15 | pass 16 | 17 | @abstractmethod 18 | def delete_index(self, index_name: str): 19 | pass 20 | 21 | @abstractmethod 22 | def convert(self, embeddings: List[EmbedData]) -> List[Dict]: 23 | pass 24 | 25 | @abstractmethod 26 | def upsert(self, data: List[Dict]): 27 | data = self.convert(data) 28 | pass 29 | -------------------------------------------------------------------------------- /python/src/config.rs: -------------------------------------------------------------------------------- 1 | use crate::EmbeddingModel; 2 | use embed_anything::config::SplittingStrategy; 3 | use pyo3::prelude::*; 4 | use processors::pdf::pdf_processor::PdfBackend; 5 | 6 | #[pyclass] 7 | #[derive(Default)] 8 | pub struct TextEmbedConfig { 9 | pub inner: embed_anything::config::TextEmbedConfig, 10 | } 11 | 12 | #[allow(clippy::too_many_arguments)] 13 | #[pymethods] 14 | impl TextEmbedConfig { 15 | #[new] 16 | #[pyo3(signature = (chunk_size=None, batch_size=None, late_chunking=None, buffer_size=None, overlap_ratio=None, splitting_strategy=None, semantic_encoder=None, use_ocr=None, tesseract_path=None, pdf_backend=None))] 17 | pub fn new( 18 | chunk_size: Option, 19 | batch_size: Option, 20 | late_chunking: Option, 21 | buffer_size: Option, 22 | overlap_ratio: Option, 23 | splitting_strategy: Option<&str>, 24 | semantic_encoder: Option<&EmbeddingModel>, 25 | use_ocr: Option, 26 | tesseract_path: Option<&str>, 27 | pdf_backend: Option<&str>, 28 | ) -> Self { 29 | let pdf_backend = match pdf_backend { 30 | Some(backend) => { 31 | match backend { 32 | "lopdf" => PdfBackend::LoPdf, 33 | _ => panic!("Unknown PDF backend provided!"), 34 | } 35 | } 36 | None => PdfBackend::LoPdf, 37 | }; 38 | 39 | let strategy = match splitting_strategy { 40 | Some(strategy) => { 41 | match strategy { 42 | "sentence" => SplittingStrategy::Sentence, 43 | "semantic" => { 44 | if semantic_encoder.is_none() { 45 | panic!("Semantic encoder is required when using Semantic splitting strategy"); 46 | } 47 | SplittingStrategy::Semantic { 48 | semantic_encoder: semantic_encoder.unwrap().inner.clone(), 49 | } 50 | } 51 | _ => panic!("Unknown strategy provided!"), 52 | } 53 | } 54 | None => SplittingStrategy::Sentence, 55 | }; 56 | 57 | Self { 58 | inner: embed_anything::config::TextEmbedConfig::default() 59 | .with_chunk_size(chunk_size.unwrap_or(1000), overlap_ratio) 60 | .with_batch_size(batch_size.unwrap_or(32)) 61 | .with_buffer_size(buffer_size.unwrap_or(100)) 62 | 
.with_splitting_strategy(strategy) 63 | .with_late_chunking(late_chunking.unwrap_or(false)) 64 | .with_ocr(use_ocr.unwrap_or(false), tesseract_path) 65 | .with_pdf_backend(pdf_backend), 66 | } 67 | } 68 | 69 | #[getter] 70 | pub fn chunk_size(&self) -> Option { 71 | self.inner.chunk_size 72 | } 73 | 74 | #[getter] 75 | pub fn batch_size(&self) -> Option { 76 | self.inner.batch_size 77 | } 78 | } 79 | 80 | #[pyclass] 81 | #[derive(Clone, Default)] 82 | pub struct ImageEmbedConfig { 83 | pub inner: embed_anything::config::ImageEmbedConfig, 84 | } 85 | 86 | #[pymethods] 87 | impl ImageEmbedConfig { 88 | #[new] 89 | #[pyo3(signature = (buffer_size=None, batch_size=None))] 90 | pub fn new(buffer_size: Option, batch_size: Option) -> Self { 91 | Self { 92 | inner: embed_anything::config::ImageEmbedConfig::new(buffer_size, batch_size), 93 | } 94 | } 95 | 96 | #[getter] 97 | pub fn buffer_size(&self) -> Option { 98 | self.inner.buffer_size 99 | } 100 | 101 | #[getter] 102 | pub fn batch_size(&self) -> Option { 103 | self.inner.batch_size 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /python/src/models/colbert.rs: -------------------------------------------------------------------------------- 1 | use std::rc::Rc; 2 | 3 | use embed_anything::embeddings::get_text_metadata; 4 | use embed_anything::embeddings::local::colbert::{ColbertEmbed, OrtColbertEmbedder}; 5 | use pyo3::exceptions::PyValueError; 6 | use pyo3::prelude::*; 7 | use pyo3::PyResult; 8 | 9 | use crate::EmbedData; 10 | 11 | #[pyclass] 12 | pub struct ColbertModel { 13 | pub model: Box, 14 | } 15 | 16 | #[pymethods] 17 | impl ColbertModel { 18 | #[new] 19 | #[pyo3(signature = (hf_model_id=None, revision=None, path_in_repo=None))] 20 | pub fn new( 21 | hf_model_id: Option<&str>, 22 | revision: Option<&str>, 23 | path_in_repo: Option<&str>, 24 | ) -> PyResult { 25 | let model = OrtColbertEmbedder::new(hf_model_id, revision, path_in_repo) 26 | .map_err(|e| PyValueError::new_err(e.to_string()))?; 27 | Ok(Self { 28 | model: Box::new(model), 29 | }) 30 | } 31 | 32 | #[staticmethod] 33 | #[pyo3(signature = (hf_model_id=None, revision=None, path_in_repo=None))] 34 | fn from_pretrained_onnx( 35 | hf_model_id: Option<&str>, 36 | revision: Option<&str>, 37 | path_in_repo: Option<&str>, 38 | ) -> PyResult { 39 | let model = OrtColbertEmbedder::new(hf_model_id, revision, path_in_repo) 40 | .map_err(|e| PyValueError::new_err(e.to_string()))?; 41 | Ok(Self { 42 | model: Box::new(model), 43 | }) 44 | } 45 | 46 | #[pyo3(signature = (text_batch, batch_size=None, is_doc=true))] 47 | pub fn embed( 48 | &self, 49 | text_batch: Vec, 50 | batch_size: Option, 51 | is_doc: bool, 52 | ) -> PyResult> { 53 | let text_batch = text_batch.iter().map(|s| s.as_str()).collect::>(); 54 | 55 | let embed_data = self 56 | .model 57 | .embed(&text_batch, batch_size, is_doc) 58 | .map_err(|e| PyValueError::new_err(e.to_string()))?; 59 | let embeddings = get_text_metadata(&Rc::new(embed_data), &text_batch, &None) 60 | .map_err(|e| PyValueError::new_err(e.to_string()))?; 61 | Ok(embeddings 62 | .into_iter() 63 | .map(|data| EmbedData { inner: data }) 64 | .collect()) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /python/src/models/colpali.rs: -------------------------------------------------------------------------------- 1 | use embed_anything::embeddings::local::colpali::ColPaliEmbed; 2 | use embed_anything::embeddings::local::colpali::ColPaliEmbedder; 3 | use 
embed_anything::embeddings::local::colpali_ort::OrtColPaliEmbedder; 4 | use pyo3::exceptions::PyValueError; 5 | use pyo3::prelude::*; 6 | use pyo3::PyResult; 7 | 8 | use crate::EmbedData; 9 | #[pyclass] 10 | pub struct ColpaliModel { 11 | pub model: Box, 12 | } 13 | 14 | #[pymethods] 15 | impl ColpaliModel { 16 | #[new] 17 | #[pyo3(signature = (model_id, revision=None))] 18 | pub fn new(model_id: &str, revision: Option<&str>) -> PyResult { 19 | let model = ColPaliEmbedder::new(model_id, revision) 20 | .map_err(|e| PyValueError::new_err(e.to_string()))?; 21 | Ok(Self { 22 | model: Box::new(model), 23 | }) 24 | } 25 | 26 | #[staticmethod] 27 | #[pyo3(signature = (model_id, revision=None))] 28 | pub fn from_pretrained(model_id: &str, revision: Option<&str>) -> PyResult { 29 | let model = ColPaliEmbedder::new(model_id, revision) 30 | .map_err(|e| PyValueError::new_err(e.to_string()))?; 31 | Ok(Self { 32 | model: Box::new(model), 33 | }) 34 | } 35 | 36 | #[staticmethod] 37 | #[pyo3(signature = (model_id, revision=None))] 38 | pub fn from_pretrained_onnx(model_id: &str, revision: Option<&str>) -> PyResult { 39 | let model = OrtColPaliEmbedder::new(model_id, revision) 40 | .map_err(|e| PyValueError::new_err(e.to_string()))?; 41 | Ok(Self { 42 | model: Box::new(model), 43 | }) 44 | } 45 | 46 | pub fn embed_file(&self, file_path: &str, batch_size: usize) -> PyResult> { 47 | let embed_data = self 48 | .model 49 | .embed_file(file_path.into(), batch_size) 50 | .map_err(|e| PyValueError::new_err(e.to_string()))?; 51 | Ok(embed_data 52 | .into_iter() 53 | .map(|data| EmbedData { inner: data }) 54 | .collect()) 55 | } 56 | 57 | pub fn embed_query(&self, query: &str) -> PyResult> { 58 | let embed_data = self 59 | .model 60 | .embed_query(query) 61 | .map_err(|e| PyValueError::new_err(e.to_string()))?; 62 | Ok(embed_data 63 | .into_iter() 64 | .map(|data| EmbedData { inner: data }) 65 | .collect()) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /python/src/models/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod colbert; 2 | pub mod colpali; 3 | pub mod reranker; 4 | -------------------------------------------------------------------------------- /python/src/models/reranker.rs: -------------------------------------------------------------------------------- 1 | use pyo3::exceptions::PyValueError; 2 | use pyo3::prelude::*; 3 | use pyo3::PyResult; 4 | 5 | #[pyclass] 6 | pub struct Reranker { 7 | pub model: embed_anything::reranker::model::Reranker, 8 | } 9 | 10 | #[pyclass(eq, eq_int)] 11 | #[derive(PartialEq)] 12 | 13 | pub enum Dtype { 14 | F16, 15 | INT8, 16 | Q4, 17 | UINT8, 18 | BNB4, 19 | Q4F16, 20 | F32, 21 | } 22 | 23 | #[pyclass] 24 | pub struct RerankerResult { 25 | pub inner: embed_anything::reranker::model::RerankerResult, 26 | } 27 | 28 | #[pyclass] 29 | pub struct DocumentRank { 30 | pub document: String, 31 | pub relevance_score: f32, 32 | pub rank: usize, 33 | } 34 | 35 | #[pymethods] 36 | impl DocumentRank { 37 | #[getter(document)] 38 | fn document(&self) -> String { 39 | self.document.clone() 40 | } 41 | 42 | #[getter(relevance_score)] 43 | fn relevance_score(&self) -> f32 { 44 | self.relevance_score 45 | } 46 | 47 | #[getter(rank)] 48 | fn rank(&self) -> usize { 49 | self.rank 50 | } 51 | 52 | fn __str__(&self) -> String { 53 | format!( 54 | "{{\"document\": \"{}\", \"relevance_score\": {}, \"rank\": {}}}", 55 | self.document, self.relevance_score, self.rank 56 | ) 57 | } 58 | 59 | fn 
__repr__(&self) -> String { 60 | format!( 61 | "DocumentRank(document={}, relevance_score={}, rank={})", 62 | self.document, self.relevance_score, self.rank 63 | ) 64 | } 65 | } 66 | 67 | #[pymethods] 68 | impl RerankerResult { 69 | #[getter(query)] 70 | fn query(&self) -> String { 71 | self.inner.query.clone() 72 | } 73 | 74 | #[getter(documents)] 75 | fn documents(&self) -> Vec { 76 | self.inner 77 | .documents 78 | .clone() 79 | .into_iter() 80 | .map(|d| DocumentRank { 81 | document: d.document, 82 | relevance_score: d.relevance_score, 83 | rank: d.rank, 84 | }) 85 | .collect() 86 | } 87 | 88 | fn __str__(&self) -> String { 89 | format!( 90 | "Query: {}\nDocuments: {}", 91 | self.query(), 92 | self.documents() 93 | .iter() 94 | .map(|d| format!( 95 | "Document: {}, Relevance Score: {}, Rank: {}", 96 | d.document, d.relevance_score, d.rank 97 | )) 98 | .collect::>() 99 | .join(", ") 100 | ) 101 | } 102 | } 103 | 104 | #[pymethods] 105 | impl Reranker { 106 | #[staticmethod] 107 | #[pyo3(signature = (model_id, revision=None, dtype=None))] 108 | pub fn from_pretrained( 109 | model_id: &str, 110 | revision: Option<&str>, 111 | dtype: Option<&Dtype>, 112 | ) -> PyResult { 113 | let dtype = match dtype { 114 | Some(Dtype::F16) => embed_anything::Dtype::F16, 115 | Some(Dtype::INT8) => embed_anything::Dtype::INT8, 116 | Some(Dtype::Q4) => embed_anything::Dtype::Q4, 117 | Some(Dtype::UINT8) => embed_anything::Dtype::UINT8, 118 | Some(Dtype::BNB4) => embed_anything::Dtype::BNB4, 119 | Some(Dtype::F32) => embed_anything::Dtype::F32, 120 | _ => embed_anything::Dtype::F32, 121 | }; 122 | let model = embed_anything::reranker::model::Reranker::new(model_id, revision, dtype) 123 | .map_err(|e| PyValueError::new_err(e.to_string()))?; 124 | Ok(Self { model }) 125 | } 126 | 127 | #[pyo3(signature = (query, documents, batch_size))] 128 | pub fn rerank( 129 | &self, 130 | query: Vec, 131 | documents: Vec, 132 | batch_size: usize, 133 | ) -> PyResult> { 134 | let query_refs: Vec<&str> = query.iter().map(|s| s.as_str()).collect(); 135 | let document_refs: Vec<&str> = documents.iter().map(|s| s.as_str()).collect(); 136 | let results = self 137 | .model 138 | .rerank(query_refs, document_refs, batch_size) 139 | .unwrap(); 140 | Ok(results 141 | .into_iter() 142 | .map(|r| RerankerResult { inner: r }) 143 | .collect::>()) 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /rust/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "embed_anything" 3 | version.workspace = true 4 | edition.workspace = true 5 | license.workspace = true 6 | description.workspace = true 7 | repository.workspace = true 8 | authors.workspace = true 9 | readme = "../README.md" 10 | 11 | [dependencies] 12 | # File preprocessing 13 | processors = { path = "../processors" } 14 | 15 | # Data Serialization 16 | serde = { version = "1.0.196", features = ["derive"] } 17 | serde_json = "1.0.112" 18 | 19 | # HTTP Client 20 | reqwest = { version = "0.12.2", default-features = false, features = ["json", "blocking"] } 21 | 22 | # Filesystem 23 | walkdir = "2.4.0" 24 | 25 | # Regular Expressions 26 | regex = "1.10.3" 27 | 28 | # Parallelism 29 | rayon = "1.8.1" 30 | 31 | # Natural Language Processing 32 | tokenizers = {version="0.21.1", default-features = false, features=["http"]} 33 | text-splitter = {version= "0.25.1", features=["tokenizers", "markdown"]} 34 | 35 | tracing = "0.1.41" 36 | 37 | # Hugging Face Libraries 38 | hf-hub = 
{ version = "0.4.1", default-features = false } 39 | candle-nn = { workspace = true } 40 | candle-transformers = { workspace = true } 41 | candle-core = { workspace = true } 42 | 43 | # Error Handling 44 | anyhow = "1.0.89" 45 | 46 | # Asynchronous Programming 47 | tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] } 48 | 49 | chrono = "0.4.38" 50 | rand = "0.9.0" 51 | itertools = "0.14.0" 52 | 53 | # Image processing 54 | image = "0.25.6" 55 | 56 | # Audio Processing 57 | symphonia = { version = "0.5.3", optional = true, features = ["all"] } 58 | byteorder = "1.5.0" 59 | 60 | ndarray = "0.16.1" 61 | pdf2image = "0.1.2" 62 | strum = {workspace = true} 63 | base64 = "0.22.1" 64 | # Optional Dependency 65 | intel-mkl-src = { version = "0.8.1", optional = true } 66 | accelerate-src = { version = "0.3.2", optional = true } 67 | indicatif = "0.17.8" 68 | statistical = "1.0.0" 69 | half = "2.4.1" 70 | candle-flash-attn = { workspace = true, optional = true } 71 | model2vec-rs = "0.1.1" 72 | 73 | # Logging 74 | log = "0.4" 75 | 76 | [dev-dependencies] 77 | tempdir = "0.3.7" 78 | lazy_static = "1.4.0" 79 | clap = { version = "4.5.20", features = ["derive"] } 80 | 81 | [target.'cfg(not(target_os = "macos"))'.dependencies] 82 | ort = {version = "=2.0.0-rc.9", features = ["load-dynamic"], optional = true} 83 | 84 | [target.'cfg(target_os = "macos")'.dependencies] 85 | ort = {version = "=2.0.0-rc.9", optional = true} 86 | 87 | 88 | [features] 89 | default = ['rustls-tls'] 90 | mkl = ["dep:intel-mkl-src", "candle-nn/mkl", "candle-transformers/mkl", "candle-core/mkl"] 91 | accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] 92 | cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "ort/cuda"] 93 | cudnn = ["candle-core/cudnn"] 94 | flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] 95 | metal = ["candle-core/metal", "candle-nn/metal"] 96 | audio = ["dep:symphonia"] 97 | ort = ["dep:ort"] 98 | rustls-tls = [ 99 | "reqwest/rustls-tls", 100 | "hf-hub/rustls-tls", 101 | "tokenizers/rustls-tls" 102 | ] 103 | -------------------------------------------------------------------------------- /rust/examples/audio.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use embed_anything::{ 4 | config::{SplittingStrategy, TextEmbedConfig}, 5 | emb_audio, 6 | embeddings::embed::EmbedderBuilder, 7 | file_processor::audio::audio_processor::AudioDecoderModel, 8 | }; 9 | 10 | #[tokio::main] 11 | async fn main() { 12 | let audio_path = std::path::PathBuf::from("test_files/audio/samples_hp0.wav"); 13 | let mut audio_decoder = AudioDecoderModel::from_pretrained( 14 | Some("openai/whisper-tiny.en"), 15 | Some("main"), 16 | "tiny-en", 17 | false, 18 | ) 19 | .unwrap(); 20 | 21 | let bert_model = Arc::new( 22 | EmbedderBuilder::new() 23 | .model_architecture("bert") 24 | .model_id(Some("sentence-transformers/all-MiniLM-L6-v2")) 25 | .revision(None) 26 | .token(None) 27 | .from_pretrained_hf() 28 | .unwrap(), 29 | ); 30 | 31 | let text_embed_config = TextEmbedConfig::default() 32 | .with_chunk_size(1000, Some(0.3)) 33 | .with_batch_size(32) 34 | .with_splitting_strategy(SplittingStrategy::Sentence); 35 | 36 | let embeddings = emb_audio( 37 | audio_path, 38 | &mut audio_decoder, 39 | &bert_model, 40 | Some(&text_embed_config), 41 | ) 42 | .await 43 | .unwrap() 44 | .unwrap(); 45 | 46 | println!("{:?}", embeddings); 47 | } 48 | 
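// (Editor's note, not part of the upstream example.) In rust/Cargo.toml above, `symphonia` is an
// optional dependency gated behind the `audio` feature, so this audio example presumably has to be
// compiled with that feature enabled, for instance something along the lines of:
//
//     cargo run --release --example audio --features audio
//
// run from the `rust` directory. The exact invocation is an assumption; the crate's README or CI
// workflow would have the canonical command.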
-------------------------------------------------------------------------------- /rust/examples/bert.rs: -------------------------------------------------------------------------------- 1 | use embed_anything::config::{SplittingStrategy, TextEmbedConfig}; 2 | use embed_anything::embeddings::embed::EmbedderBuilder; 3 | use embed_anything::Dtype; 4 | use std::collections::HashSet; 5 | use std::sync::Arc; 6 | use std::{path::PathBuf, time::Instant}; 7 | 8 | #[tokio::main] 9 | async fn main() { 10 | let model = Arc::new( 11 | EmbedderBuilder::new() 12 | .model_architecture("jina") 13 | .model_id(Some("jinaai/jina-embeddings-v2-small-en")) 14 | .revision(None) 15 | .token(None) 16 | .dtype(Some(Dtype::F16)) 17 | .from_pretrained_hf() 18 | .unwrap(), 19 | ); 20 | 21 | let config = TextEmbedConfig::default() 22 | .with_chunk_size(1000, Some(0.3)) 23 | .with_batch_size(32) 24 | .with_buffer_size(32) 25 | .with_splitting_strategy(SplittingStrategy::Semantic { 26 | semantic_encoder: model.clone(), 27 | }); 28 | 29 | let now = Instant::now(); 30 | 31 | // Embed files batch 32 | let _out_2 = model 33 | .embed_files_batch( 34 | vec!["test_files/test.pdf", "test_files/test.txt"], 35 | Some(&config), 36 | None, 37 | ) 38 | .await 39 | .unwrap() 40 | .unwrap(); 41 | 42 | // Embed file 43 | let _out = model 44 | .embed_file("test_files/test.pdf", Some(&config), None) 45 | .await 46 | .unwrap() 47 | .unwrap(); 48 | 49 | let elapsed_time: std::time::Duration = now.elapsed(); 50 | 51 | println!("Elapsed Time: {}", elapsed_time.as_secs_f32()); 52 | 53 | let now = Instant::now(); 54 | 55 | // Embed a directory 56 | let _out = model 57 | .embed_directory_stream( 58 | PathBuf::from("test_files"), 59 | Some(vec!["pdf".to_string(), "txt".to_string()]), 60 | Some(&config), 61 | None, 62 | ) 63 | .await 64 | .unwrap() 65 | .unwrap(); 66 | 67 | // Embed an html file 68 | let _out2 = model 69 | .embed_webpage( 70 | "https://www.google.com".to_string(), 71 | Some(&config), 72 | None, 73 | ) 74 | .await 75 | .unwrap() 76 | .unwrap(); 77 | 78 | let embedded_files = _out 79 | .iter() 80 | .map(|e| { 81 | e.metadata 82 | .as_ref() 83 | .unwrap() 84 | .get("file_name") 85 | .unwrap() 86 | .clone() 87 | }) 88 | .collect::>(); 89 | let mut embedded_files_set = HashSet::new(); 90 | embedded_files_set.extend(embedded_files); 91 | println!("Embedded files: {:?}", embedded_files_set); 92 | 93 | println!("Number of chunks: {:?}", _out.len()); 94 | let elapsed_time: std::time::Duration = now.elapsed(); 95 | println!("Elapsed Time: {}", elapsed_time.as_secs_f32()); 96 | } 97 | -------------------------------------------------------------------------------- /rust/examples/clip.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{Device, Tensor}; 2 | use embed_anything::{ 3 | embed_image_directory, embed_query, 4 | embeddings::embed::{Embedder, EmbedderBuilder}, 5 | }; 6 | use std::{path::PathBuf, sync::Arc, time::Instant}; 7 | 8 | #[tokio::main] 9 | async fn main() { 10 | let now = Instant::now(); 11 | 12 | let model = EmbedderBuilder::new() 13 | .model_architecture("clip") 14 | .model_id(Some("google/siglip-base-patch16-224")) 15 | .revision(None) 16 | .token(None) 17 | .from_pretrained_hf() 18 | .unwrap(); 19 | let model: Arc = Arc::new(model); 20 | let out = embed_image_directory(PathBuf::from("test_files"), &model, None, None) 21 | .await 22 | .unwrap() 23 | .unwrap(); 24 | 25 | 26 | let query_emb_data = embed_query(&["Photo of a monkey?"], &model, None) 27 | .await 28 | 
.unwrap(); 29 | let n_vectors = out.len(); 30 | 31 | let vector = out 32 | .iter() 33 | .map(|embed| embed.embedding.clone()) 34 | .collect::>() 35 | .into_iter() 36 | .flat_map(|x| x.to_dense().unwrap()) 37 | .collect::>(); 38 | 39 | let out_embeddings = Tensor::from_vec( 40 | vector, 41 | (n_vectors, out[0].embedding.to_dense().unwrap().len()), 42 | &Device::Cpu, 43 | ) 44 | .unwrap(); 45 | 46 | let image_paths = out 47 | .iter() 48 | .map(|embed| embed.text.clone().unwrap()) 49 | .collect::>(); 50 | 51 | let query_embeddings = Tensor::from_vec( 52 | query_emb_data 53 | .iter() 54 | .map(|embed| embed.embedding.clone()) 55 | .collect::>() 56 | .into_iter() 57 | .flat_map(|x| x.to_dense().unwrap()) 58 | .collect::>(), 59 | (1, query_emb_data[0].embedding.to_dense().unwrap().len()), 60 | &Device::Cpu, 61 | ) 62 | .unwrap(); 63 | 64 | let similarities = out_embeddings 65 | .matmul(&query_embeddings.transpose(0, 1).unwrap()) 66 | .unwrap() 67 | .detach() 68 | .squeeze(1) 69 | .unwrap() 70 | .to_vec1::() 71 | .unwrap(); 72 | 73 | let mut indices: Vec = (0..similarities.len()).collect(); 74 | indices.sort_by(|a, b| similarities[*b].partial_cmp(&similarities[*a]).unwrap()); 75 | 76 | println!("Descending order of similarity: "); 77 | for idx in &indices { 78 | println!("{}", image_paths[*idx]); 79 | } 80 | 81 | println!("-----------"); 82 | 83 | println!("Most similar image: {}", image_paths[indices[0]]); 84 | 85 | let elapsed_time = now.elapsed(); 86 | println!("Elapsed Time: {}", elapsed_time.as_secs_f32()); 87 | } 88 | -------------------------------------------------------------------------------- /rust/examples/cloud.rs: -------------------------------------------------------------------------------- 1 | use std::{path::PathBuf, sync::Arc}; 2 | 3 | use embed_anything::{ 4 | config::TextEmbedConfig, 5 | embed_directory_stream, embed_file, 6 | embeddings::embed::Embedder, 7 | }; 8 | 9 | use anyhow::Result; 10 | 11 | #[tokio::main] 12 | async fn main() -> Result<()> { 13 | let text_embed_config = TextEmbedConfig::default() 14 | .with_chunk_size(1000, Some(0.3)) 15 | .with_batch_size(512) 16 | .with_buffer_size(512); 17 | let cohere_model = 18 | Embedder::from_pretrained_cloud("cohere", "embed-english-v3.0", None).unwrap(); // You can add your api key here 19 | let openai_model = 20 | Embedder::from_pretrained_cloud("openai", "text-embedding-3-small", None).unwrap(); // You can add your api key here 21 | let openai_model: Arc = Arc::new(openai_model); 22 | let _openai_embeddings = embed_directory_stream( 23 | PathBuf::from("test_files"), 24 | &openai_model, 25 | Some(vec!["pdf".to_string()]), 26 | Some(&text_embed_config), 27 | None, 28 | ) 29 | .await? 30 | .unwrap(); 31 | 32 | let _file_embedding = embed_file( 33 | "test_files/attention.pdf", 34 | &openai_model, 35 | Some(&text_embed_config), 36 | None, 37 | ) 38 | .await? 39 | .unwrap(); 40 | 41 | let _cohere_embedding = embed_file( 42 | "test_files/attention.pdf", 43 | &cohere_model, 44 | Some(&text_embed_config), 45 | None, 46 | ) 47 | .await? 
48 | .unwrap(); 49 | 50 | println!("Cohere embedding: {:?}", _cohere_embedding); 51 | 52 | Ok(()) 53 | } 54 | -------------------------------------------------------------------------------- /rust/examples/cohere_pdf.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use embed_anything::{ 4 | config::TextEmbedConfig, 5 | embeddings::embed::Embedder, 6 | }; 7 | 8 | use anyhow::Result; 9 | 10 | #[tokio::main] 11 | async fn main() -> Result<()> { 12 | let text_embed_config = TextEmbedConfig::default() 13 | .with_chunk_size(1000, Some(0.3)) 14 | .with_batch_size(8) 15 | .with_buffer_size(8); 16 | let cohere_model: Arc = Arc::new( 17 | Embedder::from_pretrained_cloud("cohere-vision", "embed-v4.0", None).unwrap() 18 | ); 19 | 20 | // let embeddings = cohere_model.embed_query(&["What are Positional Encodings"], Some(&text_embed_config)).await?; 21 | let embeddings = cohere_model.embed_file("test_files/colpali.pdf", Some(&text_embed_config), None).await?; 22 | println!("{:?}", embeddings.unwrap().len()); 23 | Ok(()) 24 | } 25 | -------------------------------------------------------------------------------- /rust/examples/colbert.rs: -------------------------------------------------------------------------------- 1 | use embed_anything::config::{SplittingStrategy, TextEmbedConfig}; 2 | use embed_anything::embeddings::embed::{Embedder, EmbeddingResult}; 3 | use embed_anything::{embed_file, embed_query}; 4 | use rayon::prelude::*; 5 | use std::sync::Arc; 6 | use std::time::Instant; 7 | 8 | #[tokio::main] 9 | async fn main() -> Result<(), anyhow::Error> { 10 | let model = Arc::new( 11 | Embedder::from_pretrained_onnx( 12 | "colbert", 13 | // Some(ONNXModel::ModernBERTBase), 14 | None, 15 | None, 16 | Some("answerdotai/answerai-colbert-small-v1"), 17 | None, 18 | Some("onnx/model_fp16.onnx"), 19 | ) 20 | .unwrap(), 21 | ); 22 | 23 | let config = TextEmbedConfig::default() 24 | .with_chunk_size(1000, Some(0.3)) 25 | .with_batch_size(32) 26 | .with_buffer_size(1000) 27 | .with_splitting_strategy(SplittingStrategy::Sentence); 28 | 29 | // get files in bench 30 | let files = std::fs::read_dir("bench") 31 | .unwrap() 32 | .map(|f| f.unwrap().path()) 33 | .collect::>(); 34 | 35 | let now = Instant::now(); 36 | 37 | let futures = files 38 | .par_iter() 39 | .map(|file| embed_file(file, &model, Some(&config), None)) 40 | .collect::>(); 41 | 42 | let _data = futures.into_iter().next().unwrap().await?.unwrap(); 43 | 44 | let elapsed_time = now.elapsed(); 45 | println!("Elapsed Time: {}", elapsed_time.as_secs_f32()); 46 | 47 | let sentences = [ 48 | "The quick brown fox jumps over the lazy dog", 49 | "The cat is sleeping on the mat", 50 | "The dog is barking at the moon", 51 | "I love pizza", 52 | "The dog is sitting in the park", 53 | "Der Hund sitzt im Park", // German for "The dog is sitting in the park" 54 | "pizza is the best", 55 | "मैं पिज्जा पसंद करता हूं", // Hindi for "I like pizza" 56 | ]; 57 | 58 | let doc_embeddings = embed_query(&sentences, &model, Some(&config)) 59 | .await 60 | .unwrap(); 61 | 62 | // print out the embeddings for the first sentence 63 | let EmbeddingResult::MultiVector(vec) = doc_embeddings[0].embedding.clone() else { 64 | panic!("output should be a multi vector"); 65 | }; 66 | for (i, v) in vec.iter().enumerate() { 67 | println!("{}: {:?}", i, v); 68 | } 69 | 70 | Ok(()) 71 | } 72 | -------------------------------------------------------------------------------- /rust/examples/colpali.rs: 
-------------------------------------------------------------------------------- 1 | use clap::{Parser, ValueEnum}; 2 | use embed_anything::embeddings::local::colpali::{ColPaliEmbed, ColPaliEmbedder}; 3 | 4 | #[cfg(feature = "ort")] 5 | use embed_anything::embeddings::local::colpali_ort::OrtColPaliEmbedder; 6 | 7 | #[derive(Parser, Debug, Clone, ValueEnum)] 8 | enum ModelType { 9 | Ort, 10 | Normal, 11 | } 12 | 13 | #[derive(Parser, Debug)] 14 | #[command(author, version, about, long_about = None)] 15 | struct Args { 16 | /// Choose model type: 'ort' or 'normal' 17 | #[arg(short, long, default_value = "normal")] 18 | model_type: ModelType, 19 | } 20 | 21 | fn main() -> Result<(), anyhow::Error> { 22 | let args = Args::parse(); 23 | 24 | let colpali_model = match args.model_type { 25 | ModelType::Ort => { 26 | #[cfg(feature = "ort")] 27 | { 28 | Box::new(OrtColPaliEmbedder::new( 29 | "akshayballal/colpali-v1.2-merged-onnx", 30 | None, 31 | )?) as Box 32 | } 33 | #[cfg(not(feature = "ort"))] 34 | { 35 | panic!("ORT is not supported without ORT"); 36 | } 37 | } 38 | ModelType::Normal => Box::new(ColPaliEmbedder::new("vidore/colpali-v1.2-merged", None)?) 39 | as Box, 40 | }; 41 | // ... rest of the code ... 42 | let file_path = "test_files/attention.pdf"; 43 | let batch_size = 4; 44 | let embed_data = colpali_model.embed_file(file_path.into(), batch_size)?; 45 | println!("{:?}", embed_data.len()); 46 | 47 | let prompt = "What is attention?"; 48 | let query_embeddings = colpali_model.embed_query(prompt)?; 49 | println!("{:?}", query_embeddings.len()); 50 | Ok(()) 51 | } 52 | -------------------------------------------------------------------------------- /rust/examples/late_chunking.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use embed_anything::{config::TextEmbedConfig, embeddings::embed::EmbedderBuilder}; 4 | 5 | #[tokio::main] 6 | async fn main() { 7 | let model = Arc::new( 8 | EmbedderBuilder::new() 9 | .model_architecture("jina") 10 | .model_id(Some("jinaai/jina-embeddings-v2-small-en")) 11 | .revision(None) 12 | .path_in_repo(Some("model.onnx")) 13 | .from_pretrained_onnx() 14 | .unwrap(), 15 | ); 16 | 17 | let config = TextEmbedConfig::default() 18 | .with_chunk_size(1000, Some(0.3)) 19 | .with_batch_size(4) 20 | .with_buffer_size(32) 21 | .with_late_chunking(true); 22 | 23 | let out = model 24 | .embed_file("test_files/test.pdf", Some(&config), None) 25 | .await 26 | .unwrap() 27 | .unwrap(); 28 | 29 | for d in out { 30 | println!("{}", d.text.unwrap()); 31 | println!("---"); 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /rust/examples/model2vec.rs: -------------------------------------------------------------------------------- 1 | use embed_anything::config::{SplittingStrategy, TextEmbedConfig}; 2 | use embed_anything::embeddings::embed::EmbedderBuilder; 3 | use processors::pdf::pdf_processor::PdfBackend; 4 | use embed_anything::Dtype; 5 | use std::collections::HashSet; 6 | use std::sync::Arc; 7 | use std::{path::PathBuf, time::Instant}; 8 | 9 | #[tokio::main] 10 | async fn main() { 11 | let model = Arc::new( 12 | EmbedderBuilder::new() 13 | .model_architecture("model2vec") 14 | .model_id(Some("minishlab/potion-base-8M")) 15 | .revision(None) 16 | .token(None) 17 | .dtype(Some(Dtype::F16)) 18 | .from_pretrained_hf() 19 | .unwrap(), 20 | ); 21 | 22 | let config = TextEmbedConfig::default() 23 | .with_chunk_size(1000, Some(0.3)) 24 | .with_batch_size(32) 25 | 
.with_buffer_size(32) 26 | .with_splitting_strategy(SplittingStrategy::Semantic { 27 | semantic_encoder: model.clone(), 28 | }) 29 | .with_pdf_backend(PdfBackend::LoPdf); 30 | 31 | let now = Instant::now(); 32 | 33 | // Embed files batch 34 | let _out_2 = model 35 | .embed_files_batch( 36 | vec!["test_files/test.pdf", "test_files/test.txt"], 37 | Some(&config), 38 | None, 39 | ) 40 | .await 41 | .unwrap() 42 | .unwrap(); 43 | 44 | // Embed file 45 | let _out = model 46 | .embed_file("test_files/test.pdf", Some(&config), None) 47 | .await 48 | .unwrap() 49 | .unwrap(); 50 | 51 | let elapsed_time: std::time::Duration = now.elapsed(); 52 | 53 | println!("Elapsed Time: {}", elapsed_time.as_secs_f32()); 54 | 55 | let now = Instant::now(); 56 | 57 | // Embed a directory 58 | let _out = model 59 | .embed_directory_stream( 60 | PathBuf::from("test_files"), 61 | Some(vec!["pdf".to_string(), "txt".to_string()]), 62 | Some(&config), 63 | None, 64 | ) 65 | .await 66 | .unwrap() 67 | .unwrap(); 68 | 69 | // Embed an html file 70 | let _out2 = model 71 | .embed_webpage( 72 | "https://www.google.com".to_string(), 73 | Some(&config), 74 | None, 75 | ) 76 | .await 77 | .unwrap() 78 | .unwrap(); 79 | 80 | let embedded_files = _out 81 | .iter() 82 | .map(|e| { 83 | e.metadata 84 | .as_ref() 85 | .unwrap() 86 | .get("file_name") 87 | .unwrap() 88 | .clone() 89 | }) 90 | .collect::>(); 91 | let mut embedded_files_set = HashSet::new(); 92 | embedded_files_set.extend(embedded_files); 93 | println!("Embedded files: {:?}", embedded_files_set); 94 | 95 | println!("Number of chunks: {:?}", _out.len()); 96 | let elapsed_time: std::time::Duration = now.elapsed(); 97 | println!("Elapsed Time: {}", elapsed_time.as_secs_f32()); 98 | } 99 | -------------------------------------------------------------------------------- /rust/examples/ort_models.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{Device, Tensor}; 2 | use embed_anything::config::{SplittingStrategy, TextEmbedConfig}; 3 | use embed_anything::embeddings::embed::EmbedderBuilder; 4 | use embed_anything::embeddings::local::text_embedding::ONNXModel; 5 | use embed_anything::{embed_file, embed_query}; 6 | use rayon::prelude::*; 7 | use std::sync::Arc; 8 | use std::time::Instant; 9 | 10 | #[tokio::main] 11 | async fn main() -> Result<(), anyhow::Error> { 12 | let model = Arc::new( 13 | EmbedderBuilder::new() 14 | .model_architecture("bert") 15 | .onnx_model_id(Some(ONNXModel::ModernBERTBase)) 16 | .from_pretrained_onnx() 17 | .unwrap(), 18 | ); 19 | 20 | let config = TextEmbedConfig::default() 21 | .with_chunk_size(1000, Some(0.3)) 22 | .with_batch_size(32) 23 | .with_buffer_size(256) 24 | .with_splitting_strategy(SplittingStrategy::Sentence); 25 | 26 | // get files in bench 27 | let files = std::fs::read_dir("bench") 28 | .unwrap() 29 | .map(|f| f.unwrap().path()) 30 | .collect::>(); 31 | 32 | let now = Instant::now(); 33 | 34 | let futures = files 35 | .par_iter() 36 | .map(|file| embed_file(file, &model, Some(&config), None)) 37 | .collect::>(); 38 | 39 | let _data = futures.into_iter().next().unwrap().await?.unwrap(); 40 | 41 | for chunk in _data { 42 | println!("--------------------------------"); 43 | 44 | println!("{:?}", chunk.text.unwrap()); 45 | println!("\n"); 46 | } 47 | 48 | let elapsed_time = now.elapsed(); 49 | println!("Elapsed Time: {}", elapsed_time.as_secs_f32()); 50 | 51 | let sentences = [ 52 | "The quick brown fox jumps over the lazy dog", 53 | "The cat is sleeping on the mat", 54 | "The 
dog is barking at the moon", 55 | "I love pizza", 56 | "The dog is sitting in the park", 57 | "Der Hund sitzt im Park", // German for "The dog is sitting in the park" 58 | "pizza is the best", 59 | "मैं पिज्जा पसंद करता हूं", // Hindi for "I like pizza" 60 | ]; 61 | let doc_embeddings = embed_query(&sentences, &model, Some(&config)) 62 | .await 63 | .unwrap(); 64 | let n_vectors = doc_embeddings.len(); 65 | let out_embeddings = Tensor::from_vec( 66 | doc_embeddings 67 | .iter() 68 | .map(|embed| embed.embedding.clone()) 69 | .collect::>() 70 | .into_iter() 71 | .flat_map(|x| x.to_dense().unwrap()) 72 | .collect::>(), 73 | ( 74 | n_vectors, 75 | doc_embeddings[0].embedding.to_dense().unwrap().len(), 76 | ), 77 | &Device::Cpu, 78 | ) 79 | .unwrap(); 80 | 81 | let mut similarities = vec![]; 82 | for i in 0..n_vectors { 83 | let e_i = out_embeddings.get(i)?; 84 | for j in (i + 1)..n_vectors { 85 | let e_j = out_embeddings.get(j)?; 86 | let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::()?; 87 | let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::()?; 88 | let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::()?; 89 | let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt(); 90 | similarities.push((cosine_similarity, i, j)) 91 | } 92 | } 93 | similarities.sort_by(|u, v| v.0.total_cmp(&u.0)); 94 | for &(score, i, j) in similarities[..5].iter() { 95 | println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j]) 96 | } 97 | 98 | Ok(()) 99 | } 100 | -------------------------------------------------------------------------------- /rust/examples/reranker.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "ort")] 2 | fn main() { 3 | use embed_anything::Dtype; 4 | 5 | let reranker = embed_anything::reranker::model::Reranker::new( 6 | "jinaai/jina-reranker-v2-base-multilingual", 7 | None, 8 | Dtype::F16, 9 | ) 10 | .unwrap(); 11 | 12 | let sentences = vec![ 13 | "The cat sits outside", 14 | "A man is playing guitar", 15 | "I love pasta", 16 | "The new movie is awesome", 17 | "The cat plays in the garden", 18 | "A woman watches TV", 19 | "The new movie is so great", 20 | "Do you like pizza?", 21 | ]; 22 | 23 | let query = vec!["There is a cat outside"]; 24 | 25 | let reranker_results = reranker.rerank(query, sentences, 32).unwrap(); 26 | let pretty_results = serde_json::to_string_pretty(&reranker_results).unwrap(); 27 | println!("{}", pretty_results); 28 | } 29 | 30 | #[cfg(not(feature = "ort"))] 31 | fn main() { 32 | println!("Reranker is not supported without ORT"); 33 | } 34 | -------------------------------------------------------------------------------- /rust/examples/splade.rs: -------------------------------------------------------------------------------- 1 | use clap::{Parser, ValueEnum}; 2 | 3 | use candle_core::{Device, Tensor}; 4 | use embed_anything::{ 5 | config::{SplittingStrategy, TextEmbedConfig}, 6 | embed_query, 7 | embeddings::{ 8 | embed::{Embedder, EmbedderBuilder}, 9 | local::text_embedding::ONNXModel, 10 | }, 11 | }; 12 | use std::sync::Arc; 13 | 14 | #[derive(Parser, Debug, Clone, ValueEnum)] 15 | enum ModelType { 16 | Ort, 17 | Normal, 18 | } 19 | 20 | #[derive(Parser, Debug)] 21 | #[command(author, version, about, long_about = None)] 22 | struct Args { 23 | /// Choose model type: 'ort' or 'normal' 24 | #[arg(short, long, default_value = "normal")] 25 | model_type: ModelType, 26 | } 27 | 28 | #[tokio::main] 29 | async fn main() -> anyhow::Result<()> { 30 | let args = Args::parse(); 31 | 32 | let model = match 
args.model_type { 33 | ModelType::Ort => Arc::new( 34 | Embedder::from_pretrained_onnx( 35 | "sparse-bert", 36 | Some(ONNXModel::SPLADEPPENV2), 37 | None, 38 | None, 39 | None, 40 | None, 41 | ) 42 | .unwrap(), 43 | ), 44 | ModelType::Normal => Arc::new( 45 | EmbedderBuilder::new() 46 | .model_architecture("sparse-bert") 47 | .model_id(Some("prithivida/Splade_PP_en_v1")) 48 | .revision(None) 49 | .from_pretrained_hf() 50 | .unwrap(), 51 | ), 52 | }; 53 | 54 | let config = TextEmbedConfig::default() 55 | .with_chunk_size(1000, Some(0.3)) 56 | .with_batch_size(32) 57 | .with_buffer_size(100) 58 | .with_splitting_strategy(SplittingStrategy::Sentence); 59 | 60 | let sentences = [ 61 | "The cat sits outside", 62 | "A man is playing guitar", 63 | "I love pasta", 64 | "The new movie is awesome", 65 | "The cat plays in the garden", 66 | "A woman watches TV", 67 | "The new movie is so great", 68 | "Do you like pizza?", 69 | ]; 70 | 71 | let n_sentences = sentences.len(); 72 | 73 | let out = embed_query(&sentences, &model, Some(&config)) 74 | .await 75 | .unwrap(); 76 | 77 | let embeddings = out 78 | .iter() 79 | .flat_map(|embed| embed.embedding.to_dense().unwrap()) 80 | .collect::>(); 81 | 82 | let embeddings_tensor = Tensor::from_vec( 83 | embeddings.clone(), 84 | (n_sentences, out[0].embedding.to_dense().unwrap().len()), 85 | &Device::Cpu, 86 | ) 87 | .unwrap(); 88 | 89 | let mut similarities = vec![]; 90 | for i in 0..n_sentences { 91 | let e_i = embeddings_tensor.get(i).unwrap(); 92 | for j in (i + 1)..n_sentences { 93 | let e_j = embeddings_tensor.get(j).unwrap(); 94 | let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::()?; 95 | let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::()?; 96 | let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::()?; 97 | let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt(); 98 | similarities.push((cosine_similarity, i, j)) 99 | } 100 | } 101 | println!("similarities: {:?}", similarities); 102 | similarities.sort_by(|u, v| v.0.total_cmp(&u.0)); 103 | for &(score, i, j) in similarities[..5].iter() { 104 | println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j]) 105 | } 106 | 107 | Ok(()) 108 | } 109 | -------------------------------------------------------------------------------- /rust/examples/web_embed.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use candle_core::Tensor; 4 | use embed_anything::{ 5 | config::{SplittingStrategy, TextEmbedConfig}, 6 | embed_query, embed_webpage, 7 | embeddings::embed::EmbedderBuilder, 8 | }; 9 | 10 | #[tokio::main] 11 | async fn main() { 12 | let start_time = std::time::Instant::now(); 13 | let url = "https://www.scrapingbee.com/blog/web-scraping-rust/".to_string(); 14 | 15 | let embedder = Arc::new( 16 | EmbedderBuilder::new() 17 | .model_architecture("bert") 18 | .model_id(Some("sentence-transformers/all-MiniLM-L6-v2")) 19 | .revision(None) 20 | .from_pretrained_hf() 21 | .unwrap(), 22 | ); 23 | 24 | let embed_config = TextEmbedConfig::default() 25 | .with_chunk_size(1000, Some(0.3)) 26 | .with_batch_size(32) 27 | .with_buffer_size(100) 28 | .with_splitting_strategy(SplittingStrategy::Sentence); 29 | 30 | let embed_data = embed_webpage(url, &embedder, Some(&embed_config), None) 31 | .await 32 | .unwrap() 33 | .unwrap(); 34 | let embeddings = embed_data 35 | .iter() 36 | .map(|data| data.embedding.to_dense().unwrap()) 37 | .collect::>(); 38 | 39 | // Convert embeddings to a tensor 40 | let embeddings = Tensor::from_vec( 41 | 
embeddings.iter().flatten().cloned().collect::>(), 42 | (embeddings.len(), embeddings[0].len()), 43 | &candle_core::Device::Cpu, 44 | ) 45 | .unwrap(); 46 | 47 | let query = ["Installation on Windows"]; 48 | let query_embedding: Vec = embed_query(&query, &embedder, Some(&embed_config)) 49 | .await 50 | .unwrap() 51 | .iter() 52 | .flat_map(|data| data.embedding.to_dense().unwrap()) 53 | .collect(); 54 | 55 | let query_embedding_tensor = Tensor::from_vec( 56 | query_embedding.clone(), 57 | (1, query_embedding.len()), 58 | &candle_core::Device::Cpu, 59 | ) 60 | .unwrap(); 61 | 62 | let similarities = embeddings 63 | .matmul(&query_embedding_tensor.transpose(0, 1).unwrap()) 64 | .unwrap() 65 | .detach() 66 | .squeeze(1) 67 | .unwrap() 68 | .to_vec1::() 69 | .unwrap(); 70 | 71 | let max_similarity_index = similarities 72 | .iter() 73 | .enumerate() 74 | .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) 75 | .unwrap() 76 | .0; 77 | let data = &embed_data[max_similarity_index].text.as_ref().unwrap(); 78 | 79 | println!("{}", data); 80 | println!("Time taken: {:?}", start_time.elapsed()); 81 | } 82 | -------------------------------------------------------------------------------- /rust/src/chunkers/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod cumulative; 2 | pub mod statistical; 3 | -------------------------------------------------------------------------------- /rust/src/embeddings/cloud/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod cohere; 2 | pub mod openai; 3 | -------------------------------------------------------------------------------- /rust/src/embeddings/cloud/openai.rs: -------------------------------------------------------------------------------- 1 | use reqwest::Client; 2 | use serde::Deserialize; 3 | use serde_json::json; 4 | 5 | use crate::embeddings::embed::EmbeddingResult; 6 | 7 | #[derive(Deserialize, Debug, Default)] 8 | pub struct OpenAIEmbedResponse { 9 | pub data: Vec, 10 | pub model: String, 11 | pub usage: Usage, 12 | } 13 | 14 | #[derive(Deserialize, Debug, Default)] 15 | pub struct EmbeddingData { 16 | pub embedding: Vec, 17 | pub index: usize, 18 | } 19 | 20 | #[derive(Deserialize, Debug, Default)] 21 | pub struct Usage { 22 | pub prompt_tokens: usize, 23 | pub total_tokens: usize, 24 | } 25 | 26 | /// Represents an OpenAIEmbeder struct that contains the URL and API key for making requests to the OpenAI API. 
27 | #[derive(Debug)] 28 | pub struct OpenAIEmbedder { 29 | url: String, 30 | model: String, 31 | api_key: String, 32 | client: Client, 33 | } 34 | 35 | impl Default for OpenAIEmbedder { 36 | fn default() -> Self { 37 | Self::new("text-embedding-3-small".to_string(), None) 38 | } 39 | } 40 | 41 | impl OpenAIEmbedder { 42 | pub fn new(model: String, api_key: Option<String>) -> Self { 43 | let api_key = 44 | api_key.unwrap_or_else(|| std::env::var("OPENAI_API_KEY").expect("API Key not set")); 45 | 46 | Self { 47 | model, 48 | url: "https://api.openai.com/v1/embeddings".to_string(), 49 | api_key, 50 | client: Client::new(), 51 | } 52 | } 53 | 54 | pub async fn embed(&self, text_batch: &[&str]) -> Result<Vec<EmbeddingResult>, anyhow::Error> { 55 | let response = self 56 | .client 57 | .post(&self.url) 58 | .header("Content-Type", "application/json") 59 | .header("Authorization", format!("Bearer {}", self.api_key)) 60 | .json(&json!({ 61 | "input": text_batch, 62 | "model": self.model, 63 | "encoding_format": "float" 64 | })) 65 | .send() 66 | .await?; 67 | let data = response.json::<OpenAIEmbedResponse>().await?; 68 | let encodings = data 69 | .data 70 | .iter() 71 | .map(|data| EmbeddingResult::DenseVector(data.embedding.clone())) 72 | .collect::<Vec<_>>(); 73 | 74 | Ok(encodings) 75 | } 76 | } 77 | 78 | #[cfg(test)] 79 | mod tests { 80 | use super::*; 81 | 82 | #[tokio::test] 83 | async fn test_openai_embed() { 84 | let openai = OpenAIEmbedder::default(); 85 | let response = openai 86 | .client 87 | .post(&openai.url) 88 | .header("Content-Type", "application/json") 89 | .header("Authorization", format!("Bearer {}", openai.api_key)) 90 | .json(&json!({ 91 | "input": vec!["Hello world"], 92 | "model": openai.model, 93 | "encoding_format": "float" 94 | })) 95 | .send() 96 | .await 97 | .unwrap(); 98 | // println!("{}", response.text().await.unwrap()); 99 | let data = response.json::<OpenAIEmbedResponse>().await.unwrap(); 100 | println!("{:?}", data); 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /rust/src/embeddings/local/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod bert; 2 | pub mod clip; 3 | #[cfg(feature = "ort")] 4 | pub mod colbert; 5 | pub mod colpali; 6 | #[cfg(feature = "ort")] 7 | pub mod colpali_ort; 8 | pub mod jina; 9 | pub mod model_info; 10 | pub mod modernbert; 11 | #[cfg(feature = "ort")] 12 | pub mod ort_bert; 13 | #[cfg(feature = "ort")] 14 | pub mod ort_jina; 15 | pub mod pooling; 16 | pub mod text_embedding; 17 | pub mod model2vec; 18 | -------------------------------------------------------------------------------- /rust/src/embeddings/local/model2vec.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Error as E; 2 | use model2vec_rs; 3 | 4 | use crate::embeddings::embed::EmbeddingResult; 5 | 6 | pub struct Model2VecEmbedder { 7 | pub model: model2vec_rs::model::StaticModel, 8 | } 9 | 10 | impl Model2VecEmbedder { 11 | pub fn new(model_id: &str, token: Option<&str>, path_in_repo: Option<&str>) -> Result<Self, E> { 12 | let model = 13 | model2vec_rs::model::StaticModel::from_pretrained(model_id, token, None, path_in_repo)?; 14 | Ok(Self { model }) 15 | } 16 | 17 | pub fn embed( 18 | &self, 19 | text_batch: &[&str], 20 | _batch_size: Option<usize>, 21 | ) -> Result<Vec<EmbeddingResult>, E> { 22 | let embeddings = self.model.encode( 23 | text_batch 24 | .iter() 25 | .map(|s| s.to_string()) 26 | .collect::<Vec<String>>() 27 | .as_slice(), 28 | ); 29 | let embeddings = embeddings 30 | .iter() 31 | .map(|e|
EmbeddingResult::DenseVector(e.clone())) 32 | .collect(); 33 | Ok(embeddings) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /rust/src/embeddings/local/model_info.rs: -------------------------------------------------------------------------------- 1 | /// Data struct about the available models 2 | #[derive(Debug, Clone)] 3 | pub struct ModelInfo<T> { 4 | pub model: T, 5 | pub dim: usize, 6 | pub description: String, 7 | pub hf_model_id: String, 8 | pub model_code: String, 9 | pub model_file: String, 10 | } 11 | -------------------------------------------------------------------------------- /rust/src/embeddings/local/pooling.rs: -------------------------------------------------------------------------------- 1 | use candle_core::Tensor; 2 | use ndarray::prelude::*; 3 | use ndarray::{Array2, Array3}; 4 | use std::ops::Mul; 5 | 6 | #[derive(Debug, Clone)] 7 | pub enum Pooling { 8 | Mean, 9 | Cls, 10 | } 11 | 12 | #[derive(Debug, Clone)] 13 | pub enum PooledOutputType { 14 | Tensor(Tensor), 15 | Array(Array2<f32>), 16 | } 17 | 18 | impl PooledOutputType { 19 | pub fn to_tensor(&self) -> Result<&Tensor, anyhow::Error> { 20 | match self { 21 | PooledOutputType::Tensor(tensor) => Ok(tensor), 22 | PooledOutputType::Array(_) => Err(anyhow::anyhow!("Cannot convert Array to Tensor")), 23 | } 24 | } 25 | 26 | pub fn to_array(&self) -> Result<&Array2<f32>, anyhow::Error> { 27 | match self { 28 | PooledOutputType::Tensor(_) => Err(anyhow::anyhow!("Cannot convert Tensor to Array")), 29 | PooledOutputType::Array(array) => Ok(array), 30 | } 31 | } 32 | } 33 | impl From<Tensor> for PooledOutputType { 34 | fn from(value: Tensor) -> Self { 35 | PooledOutputType::Tensor(value) 36 | } 37 | } 38 | 39 | impl From<Array2<f32>> for PooledOutputType { 40 | fn from(value: Array2<f32>) -> Self { 41 | PooledOutputType::Array(value) 42 | } 43 | } 44 | pub enum ModelOutput { 45 | Tensor(Tensor), 46 | Array(Array3<f32>), 47 | } 48 | 49 | impl Pooling { 50 | pub fn pool( 51 | &self, 52 | output: &ModelOutput, 53 | attention_mask: Option<&PooledOutputType>, 54 | ) -> Result<PooledOutputType, anyhow::Error> { 55 | match self { 56 | Pooling::Cls => Self::cls(output), 57 | Pooling::Mean => Self::mean(output, attention_mask), 58 | } 59 | } 60 | 61 | fn cls(output: &ModelOutput) -> Result<PooledOutputType, anyhow::Error> { 62 | match output { 63 | ModelOutput::Tensor(tensor) => tensor 64 | .get_on_dim(1, 0) 65 | .map(PooledOutputType::Tensor) 66 | .map_err(|_| anyhow::anyhow!("Cls of empty tensor")), 67 | ModelOutput::Array(array) => Ok(PooledOutputType::Array( 68 | array.slice(s![.., 0, ..]).to_owned(), 69 | )), 70 | } 71 | } 72 | 73 | fn mean( 74 | output: &ModelOutput, 75 | attention_mask: Option<&PooledOutputType>, 76 | ) -> Result<PooledOutputType, anyhow::Error> { 77 | match output { 78 | ModelOutput::Tensor(tensor) => { 79 | let attention_mask = if let Some(mask) = attention_mask { 80 | mask.to_tensor()? 81 | } else { 82 | &tensor.ones_like()? 83 | }; 84 | 85 | let expanded_mask = attention_mask 86 | .unsqueeze(2)? 87 | .expand(&[tensor.dim(0)?, tensor.dim(1)?, tensor.dim(2)?])? 88 | .to_dtype(tensor.dtype())?; 89 | 90 | let mask_sum = expanded_mask.sum_all()?.clamp(1e-10, f32::MAX)?; 91 | 92 | let result = tensor 93 | .mul(&expanded_mask)? 94 | .sum(1)? 95 | .broadcast_div(&mask_sum)?; 96 | 97 | Ok(PooledOutputType::Tensor(result)) 98 | } 99 | ModelOutput::Array(output) => { 100 | let attention_mask = attention_mask 101 | .ok_or_else(|| { 102 | anyhow::anyhow!("Attention mask required for Mean pooling output") 103 | })?
104 | .to_array()?; 105 | 106 | let mask_3d = attention_mask.view().insert_axis(Axis(2)); 107 | 108 | let mask_sum = mask_3d.iter().sum::<f32>(); 109 | 110 | let result = output 111 | .view() 112 | .mul(&mask_3d) 113 | .sum_axis(Axis(1)) 114 | .mapv(|x| x / mask_sum.clamp(1e-10, f32::MAX)); 115 | 116 | Ok(PooledOutputType::Array(result.to_owned())) 117 | } 118 | } 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /rust/src/embeddings/mod.rs: -------------------------------------------------------------------------------- 1 | //! This module contains the different embedding models that can be used to generate embeddings for the text data. 2 | 3 | use std::{collections::HashMap, rc::Rc}; 4 | 5 | use candle_core::{Device, Tensor}; 6 | use embed::{EmbedData, Embedder, EmbeddingResult}; 7 | 8 | use crate::file_processor::audio::audio_processor::Segment; 9 | 10 | pub mod cloud; 11 | pub mod embed; 12 | pub mod local; 13 | pub mod utils; 14 | 15 | use rayon::prelude::*; 16 | pub fn get_text_metadata( 17 | encodings: &Rc<Vec<EmbeddingResult>>, 18 | text_batch: &[&str], 19 | metadata: &Option<HashMap<String, String>>, 20 | ) -> anyhow::Result<Vec<EmbedData>> { 21 | let final_embeddings = encodings 22 | .par_iter() 23 | .zip(text_batch) 24 | .map(|(data, text)| EmbedData::new(data.clone(), Some(text.to_string()), metadata.clone())) 25 | .collect::<Vec<_>>(); 26 | Ok(final_embeddings) 27 | } 28 | 29 | pub fn get_audio_metadata<T: AsRef<std::path::Path>>( 30 | encodings: Vec<EmbeddingResult>, 31 | segments: Vec<Segment>, 32 | audio_file: T, 33 | ) -> Result<Vec<EmbedData>, anyhow::Error> { 34 | let final_embeddings = encodings 35 | .iter() 36 | .enumerate() 37 | .map(|(i, data)| { 38 | let mut metadata = HashMap::new(); 39 | metadata.insert("start".to_string(), segments[i].start.to_string()); 40 | metadata.insert( 41 | "end".to_string(), 42 | (segments[i].start + segments[i].duration).to_string(), 43 | ); 44 | metadata.insert( 45 | "file_name".to_string(), 46 | audio_file.as_ref().to_str().unwrap().to_string(), 47 | ); 48 | metadata.insert("text".to_string(), segments[i].dr.text.clone()); 49 | EmbedData::new( 50 | data.clone(), 51 | Some(segments[i].dr.text.clone()), 52 | Some(metadata), 53 | ) 54 | }) 55 | .collect::<Vec<_>>(); 56 | Ok(final_embeddings) 57 | } 58 | 59 | pub fn text_batch_from_audio(segments: &[Segment]) -> Vec<&str> { 60 | segments 61 | .iter() 62 | .map(|segment| segment.dr.text.as_str()) 63 | .collect() 64 | } 65 | 66 | pub async fn embed_audio<T: AsRef<std::path::Path>>( 67 | embedder: &Embedder, 68 | segments: Vec<Segment>, 69 | audio_file: T, 70 | batch_size: Option<usize>, 71 | ) -> Result<Vec<EmbedData>, anyhow::Error> { 72 | let text_batch = text_batch_from_audio(&segments); 73 | let encodings = embedder.embed(&text_batch, batch_size, None).await?; 74 | get_audio_metadata(encodings, segments, audio_file) 75 | } 76 | 77 | pub fn normalize_l2(v: &Tensor) -> candle_core::Result<Tensor> { 78 | v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
79 | } 80 | 81 | pub fn select_device() -> Device { 82 | #[cfg(feature = "metal")] 83 | { 84 | Device::new_metal(0).unwrap_or(Device::Cpu) 85 | } 86 | #[cfg(all(not(feature = "metal"), feature = "cuda"))] 87 | { 88 | Device::cuda_if_available(0).unwrap_or(Device::Cpu) 89 | } 90 | #[cfg(not(any(feature = "metal", feature = "cuda")))] 91 | { 92 | Device::Cpu 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /rust/src/embeddings/utils.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Error as E; 2 | use candle_core::{Device, Tensor}; 3 | use ndarray::Array2; 4 | use tokenizers::Tokenizer; 5 | 6 | pub fn tokenize_batch( 7 | tokenizer: &Tokenizer, 8 | text_batch: &[&str], 9 | device: &Device, 10 | ) -> anyhow::Result<(Tensor, Tensor)> { 11 | let tokens = tokenizer 12 | .encode_batch(text_batch.to_vec(), true) 13 | .map_err(E::msg)?; 14 | let token_ids = tokens 15 | .iter() 16 | .map(|tokens| { 17 | let tokens = tokens.get_ids().to_vec(); 18 | Tensor::new(tokens.as_slice(), device) 19 | }) 20 | .collect::<candle_core::Result<Vec<_>>>()?; 21 | let attention_mask = tokens 22 | .iter() 23 | .map(|tokens| { 24 | let tokens = tokens.get_attention_mask().to_vec(); 25 | Tensor::new(tokens.as_slice(), device) 26 | }) 27 | .collect::<candle_core::Result<Vec<_>>>()?; 28 | 29 | Ok(( 30 | Tensor::stack(&token_ids, 0)?, 31 | Tensor::stack(&attention_mask, 0)?, 32 | )) 33 | } 34 | 35 | pub fn get_attention_mask( 36 | tokenizer: &Tokenizer, 37 | text_batch: &[String], 38 | device: &Device, 39 | ) -> anyhow::Result<Tensor> { 40 | let tokens = tokenizer 41 | .encode_batch(text_batch.to_vec(), true) 42 | .map_err(E::msg)?; 43 | 44 | let attention_mask = tokens 45 | .iter() 46 | .map(|tokens| { 47 | let tokens = tokens.get_attention_mask().to_vec(); 48 | Tensor::new(tokens.as_slice(), device) 49 | }) 50 | .collect::<candle_core::Result<Vec<_>>>()?; 51 | Ok(Tensor::stack(&attention_mask, 0)?) 52 | } 53 | 54 | pub fn get_attention_mask_ndarray( 55 | tokenizer: &Tokenizer, 56 | text_batch: &[&str], 57 | ) -> anyhow::Result<Array2<i64>> { 58 | let attention_mask = tokenizer 59 | .encode_batch(text_batch.to_vec(), true) 60 | .map_err(E::msg)?
61 | .iter() 62 | .map(|tokens| { 63 | tokens 64 | .get_attention_mask() 65 | .iter() 66 | .map(|&id| id as i64) 67 | .collect::<Vec<i64>>() 68 | }) 69 | .collect::<Vec<Vec<i64>>>(); 70 | 71 | let attention_mask_array = Array2::from_shape_vec( 72 | (attention_mask.len(), attention_mask[0].len()), 73 | attention_mask.into_iter().flatten().collect::<Vec<i64>>(), 74 | ) 75 | .unwrap(); 76 | Ok(attention_mask_array) 77 | } 78 | 79 | pub fn tokenize_batch_ndarray( 80 | tokenizer: &Tokenizer, 81 | text_batch: &[&str], 82 | ) -> anyhow::Result<(Array2<i64>, Array2<i64>)> { 83 | let tokens = tokenizer 84 | .encode_batch(text_batch.to_vec(), true) 85 | .map_err(E::msg)?; 86 | let token_ids = tokens 87 | .iter() 88 | .map(|tokens| { 89 | tokens 90 | .get_ids() 91 | .iter() 92 | .map(|&id| id as i64) 93 | .collect::<Vec<i64>>() 94 | }) 95 | .collect::<Vec<Vec<i64>>>(); 96 | let attention_mask = tokens 97 | .iter() 98 | .map(|tokens| { 99 | tokens 100 | .get_attention_mask() 101 | .to_vec() 102 | .iter() 103 | .map(|&id| id as i64) 104 | .collect::<Vec<i64>>() 105 | }) 106 | .collect::<Vec<Vec<i64>>>(); 107 | let token_ids_array = Array2::from_shape_vec( 108 | (token_ids.len(), token_ids[0].len()), 109 | token_ids.into_iter().flatten().collect::<Vec<i64>>(), 110 | ) 111 | .unwrap(); 112 | let attention_mask_array = Array2::from_shape_vec( 113 | (attention_mask.len(), attention_mask[0].len()), 114 | attention_mask.into_iter().flatten().collect::<Vec<i64>>(), 115 | ) 116 | .unwrap(); 117 | Ok((token_ids_array, attention_mask_array)) 118 | } 119 | 120 | pub fn get_type_ids_ndarray( 121 | tokenizer: &Tokenizer, 122 | text_batch: &[&str], 123 | ) -> anyhow::Result<Array2<i64>> { 124 | let token_ids = tokenizer 125 | .encode_batch(text_batch.to_vec(), true) 126 | .map_err(E::msg)? 127 | .iter() 128 | .map(|tokens| { 129 | tokens 130 | .get_type_ids() 131 | .iter() 132 | .map(|&id| id as i64) 133 | .collect::<Vec<i64>>() 134 | }) 135 | .collect::<Vec<Vec<i64>>>(); 136 | 137 | let token_ids_array = Array2::from_shape_vec( 138 | (token_ids.len(), token_ids[0].len()), 139 | token_ids.into_iter().flatten().collect::<Vec<i64>>(), 140 | ) 141 | .unwrap(); 142 | Ok(token_ids_array) 143 | } 144 | -------------------------------------------------------------------------------- /rust/src/file_processor/audio/melfilters.bytes: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/rust/src/file_processor/audio/melfilters.bytes -------------------------------------------------------------------------------- /rust/src/file_processor/audio/melfilters128.bytes: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/rust/src/file_processor/audio/melfilters128.bytes -------------------------------------------------------------------------------- /rust/src/file_processor/audio/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod audio_processor; 2 | pub mod pcm_decode; 3 | -------------------------------------------------------------------------------- /rust/src/file_processor/audio/pcm_decode.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "audio")] 2 | pub mod audio_processing { 3 | use symphonia::core::audio::{AudioBufferRef, Signal}; 4 | use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL}; 5 | use symphonia::core::conv::FromSample; 6 | 7 | fn conv<T>( 8 | samples: &mut Vec<f32>, 9 | data:
std::borrow::Cow>, 10 | ) where 11 | T: symphonia::core::sample::Sample, 12 | f32: symphonia::core::conv::FromSample, 13 | { 14 | samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v))) 15 | } 16 | 17 | pub(crate) fn pcm_decode>( 18 | path: P, 19 | ) -> anyhow::Result<(Vec, u32)> { 20 | // Open the media source. 21 | let src = std::fs::File::open(path)?; 22 | 23 | // Create the media source stream. 24 | let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default()); 25 | 26 | // Create a probe hint using the file's extension. [Optional] 27 | let hint = symphonia::core::probe::Hint::new(); 28 | 29 | // Use the default options for metadata and format readers. 30 | let meta_opts: symphonia::core::meta::MetadataOptions = Default::default(); 31 | let fmt_opts: symphonia::core::formats::FormatOptions = Default::default(); 32 | 33 | // Probe the media source. 34 | let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?; 35 | // Get the instantiated format reader. 36 | let mut format = probed.format; 37 | 38 | // Find the first audio track with a known (decodeable) codec. 39 | let track = format 40 | .tracks() 41 | .iter() 42 | .find(|t| t.codec_params.codec != CODEC_TYPE_NULL) 43 | .expect("no supported audio tracks"); 44 | 45 | // Use the default options for the decoder. 46 | let dec_opts: DecoderOptions = Default::default(); 47 | 48 | // Create a decoder for the track. 49 | let mut decoder = symphonia::default::get_codecs() 50 | .make(&track.codec_params, &dec_opts) 51 | .expect("unsupported codec"); 52 | let track_id = track.id; 53 | let sample_rate = track.codec_params.sample_rate.unwrap_or(0); 54 | let mut pcm_data = Vec::new(); 55 | // The decode loop. 56 | while let Ok(packet) = format.next_packet() { 57 | // Consume any new metadata that has been read since the last packet. 58 | while !format.metadata().is_latest() { 59 | format.metadata().pop(); 60 | } 61 | 62 | // If the packet does not belong to the selected track, skip over it. 63 | if packet.track_id() != track_id { 64 | continue; 65 | } 66 | match decoder.decode(&packet)? 
{ 67 | AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)), 68 | AudioBufferRef::U8(data) => conv(&mut pcm_data, data), 69 | AudioBufferRef::U16(data) => conv(&mut pcm_data, data), 70 | AudioBufferRef::U24(data) => conv(&mut pcm_data, data), 71 | AudioBufferRef::U32(data) => conv(&mut pcm_data, data), 72 | AudioBufferRef::S8(data) => conv(&mut pcm_data, data), 73 | AudioBufferRef::S16(data) => conv(&mut pcm_data, data), 74 | AudioBufferRef::S24(data) => conv(&mut pcm_data, data), 75 | AudioBufferRef::S32(data) => conv(&mut pcm_data, data), 76 | AudioBufferRef::F64(data) => conv(&mut pcm_data, data), 77 | } 78 | } 79 | Ok((pcm_data, sample_rate)) 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /rust/src/file_processor/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod audio; 2 | -------------------------------------------------------------------------------- /rust/src/models/colpali.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{Module, Result, Tensor}; 2 | use candle_nn::VarBuilder; 3 | 4 | use super::paligemma; 5 | use candle_nn::{linear, Linear}; 6 | 7 | pub struct Model { 8 | pub model: paligemma::Model, 9 | pub custom_text_projection: Linear, 10 | } 11 | 12 | impl Model { 13 | pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result { 14 | let model = paligemma::Model::new(config, vb.pp("model"))?; 15 | let custom_text_projection = linear( 16 | config.text_config.hidden_size, 17 | 128, 18 | vb.pp("custom_text_proj"), 19 | )?; 20 | 21 | Ok(Self { 22 | model, 23 | custom_text_projection, 24 | }) 25 | } 26 | 27 | pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result { 28 | let outputs = self 29 | .model 30 | .setup_without_projection(pixel_values, input_ids)?; 31 | let outputs = self.custom_text_projection.forward(&outputs)?; 32 | let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?; 33 | Ok(outputs) 34 | } 35 | 36 | pub fn forward_text(&mut self, input_ids: &Tensor) -> Result { 37 | let outputs = self.model.forward_without_projection(input_ids)?; 38 | let outputs = self.custom_text_projection.forward(&outputs)?; 39 | let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?; 40 | Ok(outputs) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /rust/src/models/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod bert; 2 | pub mod clip; 3 | pub mod colpali; 4 | pub mod gemma; 5 | pub mod jina_bert; 6 | pub mod modernbert; 7 | pub mod paligemma; 8 | pub mod siglip; 9 | pub mod with_tracing; 10 | -------------------------------------------------------------------------------- /rust/src/reranker/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod model; 2 | -------------------------------------------------------------------------------- /rust/src/text_loader.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashMap, 3 | fmt::Debug, 4 | fs, 5 | }; 6 | 7 | use anyhow::Error; 8 | use chrono::{DateTime, Local}; 9 | use text_splitter::{Characters, ChunkConfig, TextSplitter}; 10 | 11 | impl Default for TextLoader { 12 | fn default() -> Self { 13 | Self::new(1000, 0.0) 14 | } 15 | } 16 | 17 | #[derive(Debug)] 18 | pub struct TextLoader { 19 | 
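/// Character-based text splitter used to chunk documents before embedding; its chunk size and overlap are configured in [`TextLoader::new`].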
pub splitter: TextSplitter<Characters>, 20 | } 21 | impl TextLoader { 22 | pub fn new(chunk_size: usize, overlap_ratio: f32) -> Self { 23 | Self { 24 | splitter: TextSplitter::new( 25 | ChunkConfig::new(chunk_size) 26 | .with_overlap(chunk_size * overlap_ratio as usize) 27 | .unwrap(), 28 | ), 29 | } 30 | } 31 | 32 | pub fn get_metadata<T: AsRef<std::path::Path>>( 33 | file: T, 34 | ) -> Result<HashMap<String, String>, Error> { 35 | let metadata = fs::metadata(&file).unwrap(); 36 | let mut metadata_map = HashMap::new(); 37 | metadata_map.insert( 38 | "created".to_string(), 39 | format!("{}", DateTime::<Local>::from(metadata.created()?)), 40 | ); 41 | metadata_map.insert( 42 | "modified".to_string(), 43 | format!("{}", DateTime::<Local>::from(metadata.modified()?)), 44 | ); 45 | 46 | metadata_map.insert( 47 | "file_name".to_string(), 48 | fs::canonicalize(file)?.to_str().unwrap().to_string(), 49 | ); 50 | Ok(metadata_map) 51 | } 52 | } 53 | 54 | #[cfg(test)] 55 | mod tests { 56 | use super::*; 57 | use crate::embeddings::{embed::EmbedImage, local::clip::ClipEmbedder}; 58 | use std::path::PathBuf; 59 | 60 | #[test] 61 | fn test_metadata() { 62 | let file_path = PathBuf::from("../test_files/test.pdf"); 63 | let metadata = TextLoader::get_metadata(file_path.to_str().unwrap()).unwrap(); 64 | 65 | // assert the fields that are present 66 | assert!(metadata.contains_key("created")); 67 | assert!(metadata.contains_key("modified")); 68 | assert!(metadata.contains_key("file_name")); 69 | } 70 | 71 | #[tokio::test] 72 | async fn test_image_embedder() { 73 | let file_path = PathBuf::from("../test_files/clip/cat1.jpg"); 74 | let embedder = ClipEmbedder::default(); 75 | let emb_data = embedder.embed_image(file_path, None).await.unwrap(); 76 | assert_eq!(emb_data.embedding.to_dense().unwrap().len(), 512); 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /test_files/audio/samples_hp0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/test_files/audio/samples_hp0.wav -------------------------------------------------------------------------------- /test_files/audio/samples_jfk.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/test_files/audio/samples_jfk.wav -------------------------------------------------------------------------------- /test_files/bank.txt: -------------------------------------------------------------------------------- 1 | The Bank of Elarian 2 | Nestled in the heart of the bustling city of Elarian, the Bank of Elarian stands as a beacon of financial stability and innovative banking. Founded over a century ago, this bank has grown from a modest community institution into one of the city's most trusted financial centers. Known for its majestic architecture, the building's facade is a blend of classical and modern design, featuring high pillars and sleek glass panels that reflect the city's skyline. 3 | Inside, the Bank of Elarian is equipped with state-of-the-art technology, offering customers a seamless banking experience. From advanced ATMs to virtual financial advisors, the bank has embraced digital transformation while maintaining a personal touch in customer service.
Its range of services includes traditional savings and checking accounts, investment advisory, and a highly regarded private banking division catering to high net worth individuals. 4 | The bank is also committed to community development, funding various local projects, and supporting small businesses, reflecting its deep roots in the Elarian community.Sapore di Mare Restaurant - Just a few blocks from the Bank of Elarian, Sapore di Mare offers a culinary journey with its exquisite seafood and Mediterranean cuisine. Its cozy, nautical-themed interior, complete with wooden accents and subtle lighting, creates a warm and inviting atmosphere. 5 | The restaurant's signature dish is the Fruit of the Sea platter, featuring the freshest catch from local fishermen, cooked to perfection with herbs and spices that highlight the natural flavors. The chef, a native of the Mediterranean coast, brings authentic recipes and a passion for seafood to the table, ensuring each dish is a masterpiece.Sapore di Mare is not just known for its food but also for its exceptional service. The staff go above and beyond to create a memorable dining experience, making it a popular destination for both locals and tourists.Elarian Freiseur - A short stroll from Sapore di Mare is Elarian Freiseur, a boutique hair salon known for its chic style and innovative hair treatments. This salon stands out with its modern design, featuring sleek chairs, ambient lighting, and an array of plants that add a touch of greenery and freshness. 6 | The team at Elarian Freiseur is comprised of highly skilled stylists and colorists who are experts in the latest hair trends. They offer personalized consultations to each client, ensuring a customized experience that meets individual style preferences. From classic cuts to avant-garde hair coloring, the salon is a hub for those seeking a transformative hair experience. 7 | Elarian Freiseur also places a high emphasis on using eco-friendly and sustainable hair products, aligning with the city's growing environmental consciousness. This commitment to quality and sustainability has earned it a loyal clientele who appreciate the salon's dedication to both style and the environment. 
-------------------------------------------------------------------------------- /test_files/clip/cat1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/test_files/clip/cat1.jpg -------------------------------------------------------------------------------- /test_files/clip/cat2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/test_files/clip/cat2.jpeg -------------------------------------------------------------------------------- /test_files/clip/dog1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/test_files/clip/dog1.jpg -------------------------------------------------------------------------------- /test_files/clip/dog2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/test_files/clip/dog2.jpeg -------------------------------------------------------------------------------- /test_files/clip/monkey1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/test_files/clip/monkey1.jpg -------------------------------------------------------------------------------- /test_files/colpali.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/test_files/colpali.pdf -------------------------------------------------------------------------------- /test_files/linear.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/test_files/linear.pdf -------------------------------------------------------------------------------- /test_files/test.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/test_files/test.docx -------------------------------------------------------------------------------- /test_files/test.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 |

My First Heading

6 | 7 |

My first paragraph.

8 | 9 | 10 | -------------------------------------------------------------------------------- /test_files/test.md: -------------------------------------------------------------------------------- 1 | Hello, world! 2 | # How are you 3 | ## I am good 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /test_files/test.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StarlightSearch/EmbedAnything/82b406a8afbec9507320102f24f4ce169fade329/test_files/test.pdf -------------------------------------------------------------------------------- /test_files/test.txt: -------------------------------------------------------------------------------- 1 | This is a test file to see how txt embedding works ! 2 | -------------------------------------------------------------------------------- /tests/model_tests/conftest.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import pytest 3 | from embed_anything import ( 4 | Adapter, 5 | AudioDecoderModel, 6 | EmbedData, 7 | EmbeddingModel, 8 | WhichModel, 9 | ColpaliModel, 10 | ) 11 | 12 | from embed_anything import ONNXModel 13 | 14 | 15 | @pytest.fixture 16 | def clip_model() -> EmbeddingModel: 17 | model = EmbeddingModel.from_pretrained_hf( 18 | WhichModel.Clip, model_id="openai/clip-vit-base-patch32", revision="refs/pr/15" 19 | ) 20 | return model 21 | 22 | 23 | @pytest.fixture 24 | def test_files_directory() -> str: 25 | return "test_files" 26 | 27 | 28 | @pytest.fixture 29 | def test_pdf_file(test_files_directory) -> str: 30 | return f"{test_files_directory}/test.pdf" 31 | 32 | 33 | @pytest.fixture 34 | def test_txt_file(test_files_directory) -> str: 35 | return f"{test_files_directory}/test.txt" 36 | 37 | 38 | @pytest.fixture 39 | def test_image_file(test_files_directory) -> str: 40 | return f"{test_files_directory}/clip/monkey1.jpg" 41 | 42 | 43 | @pytest.fixture 44 | def test_audio_file(test_files_directory) -> str: 45 | return f"{test_files_directory}/audio/samples_jfk.wav" 46 | 47 | 48 | @pytest.fixture 49 | def test_image_directory(test_files_directory) -> str: 50 | return f"{test_files_directory}/clip" 51 | 52 | 53 | @pytest.fixture 54 | def test_text_directory(test_files_directory) -> str: 55 | return f"{test_files_directory}" 56 | 57 | 58 | @pytest.fixture 59 | def jina_model() -> EmbeddingModel: 60 | model = EmbeddingModel.from_pretrained_hf( 61 | WhichModel.Jina, model_id="jinaai/jina-embeddings-v2-small-en", revision="main" 62 | ) 63 | return model 64 | 65 | 66 | @pytest.fixture 67 | def bert_model() -> EmbeddingModel: 68 | model = EmbeddingModel.from_pretrained_hf( 69 | WhichModel.Bert, 70 | model_id="sentence-transformers/all-MiniLM-L6-v2", 71 | revision="main", 72 | ) 73 | return model 74 | 75 | 76 | @pytest.fixture 77 | def audio_decoder() -> AudioDecoderModel: 78 | model = AudioDecoderModel.from_pretrained_hf( 79 | model_id="openai/whisper-tiny", revision="main", model_type="tiny-en" 80 | ) 81 | return model 82 | 83 | 84 | @pytest.fixture 85 | def openai_model() -> EmbeddingModel: 86 | model = EmbeddingModel.from_pretrained_cloud( 87 | WhichModel.OpenAI, model_id="text-embedding-3-small" 88 | ) 89 | return model 90 | 91 | 92 | @pytest.fixture 93 | def onnx_model() -> EmbeddingModel: 94 | model = EmbeddingModel.from_pretrained_onnx( 95 | WhichModel.Bert, ONNXModel.AllMiniLML6V2Q 96 | ) 97 | return model 98 | 99 | 100 | @pytest.fixture 101 | def 
colpali_onnx_model() -> ColpaliModel: 102 | model = ColpaliModel.from_pretrained_onnx( 103 | model_id="akshayballal/colpali-v1.2-merged-onnx" 104 | ) 105 | return model 106 | 107 | 108 | @pytest.fixture 109 | def colpali_model() -> ColpaliModel: 110 | model = ColpaliModel.from_pretrained("vidore/colpali-v1.2-merged") 111 | return model 112 | 113 | 114 | class DummyAdapter(Adapter): 115 | 116 | def create_index(self, dimension: int, metric: str, index_name: str, **kwargs): 117 | pass 118 | 119 | def delete_index(self, index_name: str): 120 | pass 121 | 122 | def convert(self, data: List[EmbedData]) -> List[EmbedData]: 123 | return data 124 | 125 | def upsert(self, data: List[EmbedData]) -> None: 126 | data = self.convert(data) 127 | return 1 128 | 129 | 130 | @pytest.fixture 131 | def dummy_adapter() -> DummyAdapter: 132 | return DummyAdapter("dummy") 133 | -------------------------------------------------------------------------------- /tests/model_tests/test_adapter.py: -------------------------------------------------------------------------------- 1 | import embed_anything 2 | import pytest 3 | 4 | 5 | def test_adapter_upsert_call_file( 6 | bert_model, dummy_adapter, test_pdf_file, test_txt_file 7 | ): 8 | assert ( 9 | embed_anything.embed_file( 10 | test_pdf_file, embedder=bert_model, adapter=dummy_adapter 11 | ) 12 | is None 13 | ) 14 | assert ( 15 | embed_anything.embed_file( 16 | test_txt_file, embedder=bert_model, adapter=dummy_adapter 17 | ) 18 | is None 19 | ) 20 | 21 | 22 | def test_adapter_upsert_call_directory(bert_model, dummy_adapter, test_files_directory): 23 | assert ( 24 | embed_anything.embed_directory( 25 | test_files_directory, embedder=bert_model, adapter=dummy_adapter 26 | ) 27 | is None 28 | ) 29 | -------------------------------------------------------------------------------- /tests/model_tests/test_audio.py: -------------------------------------------------------------------------------- 1 | from embed_anything import AudioDecoderModel, EmbeddingModel, embed_audio_file 2 | import pytest 3 | 4 | 5 | def test_audio_decoder(audio_decoder: AudioDecoderModel): 6 | assert audio_decoder is not None 7 | 8 | 9 | def test_audio_embed_file( 10 | audio_decoder: AudioDecoderModel, bert_model: EmbeddingModel, test_audio_file 11 | ): 12 | assert audio_decoder is not None 13 | assert bert_model is not None 14 | data = embed_audio_file(test_audio_file, audio_decoder, bert_model) 15 | assert data is not None 16 | assert len(data) == 1 17 | assert data[0].embedding is not None 18 | assert len(data[0].embedding) == 384 19 | -------------------------------------------------------------------------------- /tests/model_tests/test_bert.py: -------------------------------------------------------------------------------- 1 | from embed_anything import ( 2 | EmbeddingModel, 3 | TextEmbedConfig, 4 | WhichModel, 5 | embed_query, 6 | embed_file, 7 | embed_directory, 8 | ONNXModel, 9 | ) 10 | 11 | import os 12 | import pytest 13 | import tempfile 14 | import itertools 15 | 16 | # Global test parameters 17 | MODEL_FIXTURES = ["bert_model", "onnx_model"] 18 | CONFIGS = [None, TextEmbedConfig(batch_size=32, chunk_size=1000)] 19 | ALL_COMBINATIONS = list(itertools.product(MODEL_FIXTURES, CONFIGS)) 20 | 21 | # Define common parametrize decorator 22 | model_fixture_parametrize = pytest.mark.parametrize("model_fixture", MODEL_FIXTURES) 23 | model_and_config_parametrize = pytest.mark.parametrize( 24 | "model_fixture,config", ALL_COMBINATIONS 25 | ) 26 | 27 | 28 | @model_and_config_parametrize 29 | def 
test_bert_model_file(model_fixture, config, test_pdf_file, request): 30 | model = request.getfixturevalue(model_fixture) 31 | data = model.embed_file(test_pdf_file, config) 32 | path = os.path.abspath(test_pdf_file) 33 | 34 | assert len(data) > 0 35 | assert data[0].embedding is not None 36 | assert len(data[0].embedding) == 384 37 | assert data[0].metadata["file_name"] == path 38 | 39 | 40 | def test_bert_model_creation(): 41 | 42 | model = EmbeddingModel.from_pretrained_hf( 43 | WhichModel.Bert, 44 | model_id="sentence-transformers/all-MiniLM-L6-v2", 45 | revision="main", 46 | ) 47 | assert model is not None 48 | 49 | 50 | def test_onnx_model_creation(): 51 | model = EmbeddingModel.from_pretrained_onnx( 52 | WhichModel.Bert, ONNXModel.AllMiniLML6V2Q 53 | ) 54 | assert model is not None 55 | 56 | 57 | @model_fixture_parametrize 58 | def test_bert_model_query(model_fixture, request): 59 | model = request.getfixturevalue(model_fixture) 60 | data = embed_query(["Photo of a monkey?"], model) 61 | assert len(data) == 1 62 | assert data[0].embedding is not None 63 | assert len(data[0].embedding) == 384 64 | 65 | 66 | @model_and_config_parametrize 67 | def test_bert_model_directory(model_fixture, config, test_text_directory, request): 68 | model = request.getfixturevalue(model_fixture) 69 | data = embed_directory(test_text_directory, model, config=config) 70 | assert data[0].embedding is not None 71 | assert len(data[0].embedding) == 384 72 | 73 | 74 | @model_fixture_parametrize 75 | def test_bert_model_empty_query(model_fixture, request): 76 | model = request.getfixturevalue(model_fixture) 77 | data = embed_query([""], model) 78 | assert len(data) == 1 79 | assert data[0].embedding is not None 80 | assert len(data[0].embedding) == 384 81 | 82 | 83 | @model_fixture_parametrize 84 | def test_bert_model_long_query(model_fixture, request): 85 | model = request.getfixturevalue(model_fixture) 86 | long_text = " ".join(["long"] * 1000) 87 | data = embed_query([long_text], model) 88 | assert len(data) == 1 89 | assert data[0].embedding is not None 90 | assert len(data[0].embedding) == 384 91 | 92 | 93 | def test_bert_model_non_ascii_query(bert_model): 94 | non_ascii_text = "こんにちは世界" 95 | data = embed_query([non_ascii_text], bert_model) 96 | assert len(data) == 1 97 | assert data[0].embedding is not None 98 | assert len(data[0].embedding) == 384 99 | 100 | 101 | def test_bert_model_nonexistent_file(bert_model): 102 | with pytest.raises(FileNotFoundError): 103 | embed_file("nonexistent_file.txt", bert_model) 104 | 105 | 106 | def test_bert_model_empty_directory(bert_model, tmp_path): 107 | empty_dir = tmp_path / "empty_dir" 108 | empty_dir.mkdir() 109 | data = embed_directory(str(empty_dir), bert_model) 110 | assert len(data) == 0 111 | 112 | 113 | def test_bert_model_unsupported_file_type(bert_model, tmp_path): 114 | 115 | # Create a file with an unsupported extension 116 | with open(tmp_path / "unsupported.mp3", "w") as f: 117 | f.write("This is a test file") 118 | 119 | with pytest.raises(ValueError): 120 | embed_file(str(tmp_path / "unsupported.mp3"), bert_model) 121 | -------------------------------------------------------------------------------- /tests/model_tests/test_clip.py: -------------------------------------------------------------------------------- 1 | from embed_anything import ( 2 | EmbeddingModel, 3 | TextEmbedConfig, 4 | WhichModel, 5 | embed_query, 6 | embed_file, 7 | embed_directory, 8 | embed_image_directory, 9 | ) 10 | import pytest 11 | import os 12 | 13 | 14 | def 
test_clip_model_creation(): 15 | model = EmbeddingModel.from_pretrained_hf( 16 | WhichModel.Clip, model_id="openai/clip-vit-base-patch32", revision="refs/pr/15" 17 | ) 18 | 19 | assert model is not None 20 | 21 | model = EmbeddingModel.from_pretrained_hf( 22 | WhichModel.Clip, 23 | model_id="openai/clip-vit-base-patch32", 24 | ) 25 | 26 | assert model is not None 27 | 28 | 29 | def test_clip_model_query(clip_model): 30 | 31 | data = embed_query(["Photo of a monkey?"], clip_model) 32 | assert len(data) == 1 33 | assert data[0].embedding is not None 34 | assert len(data[0].embedding) == 512 35 | 36 | 37 | def test_clip_model_file(clip_model, test_image_file): 38 | 39 | data = embed_file(test_image_file, clip_model) 40 | 41 | assert data[0].embedding is not None 42 | assert len(data[0].embedding) == 512 43 | 44 | 45 | def test_clip_model_directory(clip_model): 46 | 47 | data = embed_image_directory("test_files/clip", clip_model) 48 | assert len(data) == 5 49 | assert data[0].embedding is not None 50 | assert len(data[0].embedding) == 512 51 | -------------------------------------------------------------------------------- /tests/model_tests/test_colpali.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from embed_anything import ColpaliModel 3 | 4 | 5 | @pytest.mark.parametrize("model_fixture", ["colpali_model", "colpali_onnx_model"]) 6 | def test_colpali_model_file(model_fixture, test_pdf_file, request): 7 | model: ColpaliModel = request.getfixturevalue(model_fixture) 8 | data = model.embed_file(test_pdf_file, batch_size=1) 9 | assert len(data) == 1 10 | -------------------------------------------------------------------------------- /tests/model_tests/test_jina.py: -------------------------------------------------------------------------------- 1 | from embed_anything import ( 2 | EmbeddingModel, 3 | WhichModel, 4 | embed_query, 5 | embed_file, 6 | embed_directory, 7 | ) 8 | 9 | import os 10 | 11 | 12 | def test_jina_model_creation(): 13 | 14 | model = EmbeddingModel.from_pretrained_hf( 15 | WhichModel.Jina, 16 | model_id="jinaai/jina-embeddings-v2-small-en", 17 | revision="main", 18 | ) 19 | assert model is not None 20 | 21 | 22 | def test_jina_model_query(jina_model): 23 | 24 | data = embed_query(["Photo of a monkey?"], jina_model) 25 | assert len(data) == 1 26 | assert data[0].embedding is not None 27 | assert len(data[0].embedding) == 512 28 | 29 | 30 | def test_jina_model_file(jina_model): 31 | 32 | data = embed_file("test_files/test.pdf", jina_model) 33 | path = os.path.abspath("test_files/test.pdf") 34 | assert data[0].embedding is not None 35 | assert len(data[0].embedding) == 512 36 | 37 | 38 | def test_jina_model_directory(jina_model): 39 | 40 | data = embed_directory("test_files", jina_model) 41 | assert data[0].embedding is not None 42 | assert len(data[0].embedding) == 512 43 | -------------------------------------------------------------------------------- /tests/model_tests/test_openai.py: -------------------------------------------------------------------------------- 1 | from embed_anything import TextEmbedConfig, embed_directory, embed_file, embed_query 2 | import pytest 3 | 4 | 5 | def test_openai_model_file(openai_model, test_pdf_file): 6 | data = embed_file(test_pdf_file, openai_model) 7 | assert data[0].embedding is not None 8 | assert len(data[0].embedding) == 1536 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "config", [TextEmbedConfig(batch_size=512, chunk_size=1000, buffer_size=512)] 13 | ) 14 | def 
test_openai_model_directory(openai_model, config, test_files_directory): 15 | data = embed_directory(test_files_directory, openai_model, config=config) 16 | assert data[0].embedding is not None 17 | assert len(data[0].embedding) == 1536 18 | 19 | 20 | def test_openai_model_query(openai_model): 21 | data = embed_query(["Hello world"], openai_model) 22 | assert data[0].embedding is not None 23 | assert len(data[0].embedding) == 1536 24 | --------------------------------------------------------------------------------