├── .dagger ├── .gitattributes ├── .gitignore ├── dagger.json ├── pyproject.toml ├── src │ └── semantic_router_ci │ │ ├── __init__.py │ │ └── main.py └── uv.lock ├── .env.example ├── .github └── workflows │ ├── conventional_commits.yml │ ├── diff.yml │ ├── docs.yml │ ├── lint.yml │ ├── pr_agent.yml │ ├── release.yml │ ├── test.yml │ ├── triggers_merge.yml │ └── triggers_pr.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pydoc-markdown.yml ├── .python-version ├── CLAUDE.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── bedrock.ipynb ├── commitlint.config.js ├── coverage.xml ├── diff-config.yaml ├── docs ├── 00-introduction.ipynb ├── 01-save-load-from-file.ipynb ├── 02-dynamic-routes.ipynb ├── 03-basic-langchain-agent.ipynb ├── 05-local-execution.ipynb ├── 06-threshold-optimization.ipynb ├── 07-multi-modal.ipynb ├── 08-async-dynamic-routes.ipynb ├── 09-route-filter.ipynb ├── Makefile ├── encoders │ ├── aurelio-bm25.ipynb │ ├── bedrock.ipynb │ ├── cohere.ipynb │ ├── fastembed.ipynb │ ├── google.ipynb │ ├── huggingface-endpoint.ipynb │ ├── huggingface.ipynb │ ├── jina-encoder.ipynb │ ├── local-encoder.ipynb │ ├── mistral-encoder.ipynb │ ├── nvidia_nim-encoder.ipynb │ ├── openai-embed-3.ipynb │ ├── openai-encoder.ipynb │ ├── vision-transformer.ipynb │ └── voyage-encoder.ipynb ├── examples │ ├── function_calling.ipynb │ ├── hybrid-chat-guardrails.ipynb │ ├── hybrid-router.ipynb │ ├── ollama-local-execution.ipynb │ ├── pinecone-and-scaling.ipynb │ └── pinecone-hybrid.ipynb ├── get-started │ ├── introduction.md │ └── quickstart.md ├── indexes │ ├── local.ipynb │ ├── pinecone-async.ipynb │ ├── pinecone-local.ipynb │ ├── pinecone-sync-routes.ipynb │ ├── pinecone.ipynb │ ├── pinecone_async.ipynb │ ├── postgres-sync.ipynb │ ├── postgres │ │ ├── Dockerfile │ │ ├── postgres.compose.yaml │ │ └── postgres.ipynb │ └── qdrant.ipynb ├── integrations │ ├── agents-sdk │ │ ├── hybrid-router-guardrails.ipynb │ │ └── semantic-router-guardrail.ipynb │ ├── graphai │ │ └── sparse-threshold-optimization-guardrail-graphai.ipynb │ └── pydantic-ai │ │ └── chatbot-with-guardrails.ipynb ├── make.bat └── user-guide │ ├── changelog.md │ ├── components │ ├── encoders.md │ ├── indexes.md │ └── routers.md │ ├── concepts │ ├── architecture.md │ └── overview.md │ ├── features │ ├── dynamic-routes.md │ ├── route-filter.md │ ├── sync.md │ └── threshold-optimization.md │ └── guides │ ├── configuration.md │ ├── local-execution.md │ ├── migration-to-v1.md │ ├── save-load-from-file.md │ └── semantic-router.md ├── pyproject.toml ├── replace.py ├── semantic_router ├── __init__.py ├── encoders │ ├── __init__.py │ ├── aurelio.py │ ├── azure_openai.py │ ├── base.py │ ├── bedrock.py │ ├── bm25.py │ ├── clip.py │ ├── cohere.py │ ├── encode_input_type.py │ ├── fastembed.py │ ├── google.py │ ├── huggingface.py │ ├── jina.py │ ├── litellm.py │ ├── local.py │ ├── mistral.py │ ├── nvidia_nim.py │ ├── ollama.py │ ├── openai.py │ ├── tfidf.py │ ├── vit.py │ └── voyage.py ├── index │ ├── __init__.py │ ├── base.py │ ├── hybrid_local.py │ ├── local.py │ ├── pinecone.py │ ├── postgres.py │ └── qdrant.py ├── linear.py ├── llms │ ├── __init__.py │ ├── base.py │ ├── cohere.py │ ├── grammars │ │ └── json.gbnf │ ├── llamacpp.py │ ├── mistral.py │ ├── ollama.py │ ├── openai.py │ ├── openrouter.py │ └── zure.py ├── py.typed ├── route.py ├── routers │ ├── __init__.py │ ├── base.py │ ├── hybrid.py │ └── semantic.py ├── schema.py ├── tokenizers.py └── utils │ ├── __init__.py │ ├── defaults.py │ ├── function_call.py │ └── logger.py ├── tests 
├── functional │ ├── encoders │ │ └── test_bm25_functional.py │ └── test_linear.py ├── integration │ ├── 57640.4032.txt │ ├── encoders │ │ └── test_openai_integration.py │ └── test_router_integration.py └── unit │ ├── encoders │ ├── test_azure.py │ ├── test_base.py │ ├── test_bedrock.py │ ├── test_bm25.py │ ├── test_clip.py │ ├── test_fastembed.py │ ├── test_google.py │ ├── test_hfendpointencoder.py │ ├── test_huggingface.py │ ├── test_lite_encoders.py │ ├── test_local.py │ ├── test_ollama.py │ ├── test_openai.py │ ├── test_sparse_sentence_transformer.py │ ├── test_tfidf.py │ └── test_vit.py │ ├── llms │ ├── test_llm_azure_openai.py │ ├── test_llm_base.py │ ├── test_llm_cohere.py │ ├── test_llm_llamacpp.py │ ├── test_llm_mistral.py │ ├── test_llm_ollama.py │ ├── test_llm_openai.py │ └── test_llm_openrouter.py │ ├── test_function_schema.py │ ├── test_route.py │ ├── test_router.py │ ├── test_schema.py │ ├── test_sync.py │ └── test_tokenizers.py └── uv.lock /.dagger/.gitattributes: -------------------------------------------------------------------------------- 1 | /sdk/** linguist-generated 2 | -------------------------------------------------------------------------------- /.dagger/.gitignore: -------------------------------------------------------------------------------- 1 | /.venv 2 | /**/__pycache__ 3 | /sdk 4 | /.env 5 | -------------------------------------------------------------------------------- /.dagger/dagger.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "semantic-router", 3 | "engineVersion": "v0.18.12", 4 | "sdk": { 5 | "source": "python" 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /.dagger/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "semantic-router-ci" 3 | version = "0.1.0" 4 | requires-python = ">=3.12" 5 | dependencies = ["dagger-io", "pytest", "pytest-timeout"] 6 | 7 | [project.entry-points."dagger.mod"] 8 | main_object = 'semantic_router_ci:SemanticRouter' 9 | 10 | [build-system] 11 | requires = ["hatchling==1.25.0"] 12 | build-backend = "hatchling.build" 13 | 14 | [tool.uv.sources] 15 | dagger-io = { path = "sdk", editable = true } 16 | -------------------------------------------------------------------------------- /.dagger/src/semantic_router_ci/__init__.py: -------------------------------------------------------------------------------- 1 | """A generated module for HelloDagger functions 2 | 3 | This module has been generated via dagger init and serves as a reference to 4 | basic module structure as you get started with Dagger. 5 | 6 | Two functions have been pre-created. You can modify, delete, or add to them, 7 | as needed. They demonstrate usage of arguments and return types using simple 8 | echo and grep commands. The functions can be called from the dagger CLI or 9 | from one of the SDKs. 10 | 11 | The first line in this comment block is a short description line and the 12 | rest is a long description with more detail on the module's purpose or usage, 13 | if appropriate. All modules should have a short description. 
14 | """ 15 | 16 | from .main import SemanticRouter as SemanticRouter 17 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | COHERE_API_KEY="" 2 | OPENAI_API_KEY="" 3 | PINECONE_API_KEY="" 4 | AURELIO_API_KEY="" 5 | 6 | POSTGRES_HOST="localhost" # your host 7 | POSTGRES_PORT="5432" # your port 8 | POSTGRES_DB="routes_db" # your database name 9 | POSTGRES_USER="postgres" # your username 10 | POSTGRES_PASSWORD="password" # your password -------------------------------------------------------------------------------- /.github/workflows/conventional_commits.yml: -------------------------------------------------------------------------------- 1 | # Enforces conventional commits on pull requests 2 | # Ref: https://github.com/marketplace/actions/conventional-pull-request 3 | name: Conventional commits pull request 4 | on: 5 | pull_request: 6 | branches: [main] 7 | types: [opened, edited, synchronize] 8 | jobs: 9 | lint-pr: 10 | if: "${{ !contains(github.event.pull_request.title, 'chore(main): release') }}" 11 | name: Lint PR 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Checkout repository 15 | uses: actions/checkout@v4 16 | 17 | - name: Validate 18 | uses: CondeNast/conventional-pull-request-action@v0.2.0 19 | env: 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | with: 22 | commitlintRulesPath: "./commitlint.config.js" 23 | # if the PR contains a single commit, fail if the commit message and the PR title do not match 24 | commitTitleMatch: "false" # default: 'true' 25 | # if you squash merge PRs and enabled "Default to PR title for squash merge commits", you can disable all linting of commits 26 | ignoreCommits: "false" # default: 'false' 27 | -------------------------------------------------------------------------------- /.github/workflows/diff.yml: -------------------------------------------------------------------------------- 1 | name: Diff 2 | on: 3 | workflow_call: 4 | inputs: 5 | config: 6 | type: string 7 | required: true 8 | description: 'YAML file containing modules to track' 9 | outputs: 10 | diff: 11 | description: "diff" 12 | value: "${{ jobs.diff.outputs.diff }}" 13 | tags: 14 | description: "tags" 15 | value: "${{ jobs.diff.outputs.tags }}" 16 | modules: 17 | description: "modules" 18 | value: "${{ jobs.diff.outputs.modules }}" 19 | changed: 20 | description: "changed" 21 | value: "${{ jobs.diff.outputs.changed }}" 22 | 23 | jobs: 24 | diff: 25 | name: Compare 26 | runs-on: ubuntu-latest 27 | outputs: 28 | diff: ${{ steps.run.outputs.diff }} 29 | tags: ${{ steps.run.outputs.tags }} 30 | modules: ${{ steps.run.outputs.modules }} 31 | changed: ${{ steps.run.outputs.changed }} 32 | steps: 33 | 34 | - name: Checkout 35 | uses: actions/checkout@v4 36 | 37 | - id: run 38 | name: Diff action 39 | uses: aurelio-labs/diff-action@0.2.0 40 | with: 41 | token: ${{ secrets.GITHUB_TOKEN }} 42 | config: ${{ inputs.config }} 43 | 44 | - name: Print output 45 | run: echo '${{ toJSON(steps.run.outputs) }}' | jq . 
46 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | workflow_call: 5 | 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: 13 | - "3.13" 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Install uv 17 | uses: astral-sh/setup-uv@v5 18 | with: 19 | enable-cache: true 20 | cache-dependency-glob: "uv.lock" 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install Dependencies 23 | run: | 24 | uv sync --extra dev 25 | - name: Run Lint 26 | run: | 27 | make lint 28 | 29 | -------------------------------------------------------------------------------- /.github/workflows/pr_agent.yml: -------------------------------------------------------------------------------- 1 | on: 2 | pull_request: 3 | issue_comment: 4 | jobs: 5 | pr_agent_job: 6 | runs-on: ubuntu-latest 7 | permissions: 8 | pull-requests: write 9 | issues: write 10 | contents: write 11 | name: Run pr agent on every pull request, respond to user comments 12 | steps: 13 | - name: PR Agent action step 14 | id: pragent 15 | uses: Codium-ai/pr-agent@main 16 | env: 17 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 18 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 19 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: [ '3.10' ] 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Install uv 17 | uses: astral-sh/setup-uv@v5 18 | with: 19 | enable-cache: true 20 | cache-dependency-glob: "uv.lock" 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: uv sync --extra dev 24 | - name: Build 25 | run: uv build 26 | 27 | publish: 28 | needs: build 29 | runs-on: ubuntu-latest 30 | strategy: 31 | matrix: 32 | python-version: [ '3.10' ] 33 | steps: 34 | - uses: actions/checkout@v2 35 | - name: Install uv 36 | uses: astral-sh/setup-uv@v5 37 | with: 38 | enable-cache: true 39 | cache-dependency-glob: "uv.lock" 40 | python-version: ${{ matrix.python-version }} 41 | - name: Install dependencies 42 | run: uv sync --extra dev 43 | - name: Build 44 | run: uv build 45 | - name: Publish to PyPI 46 | run: | 47 | uv publish --username "__token__" --password "$PYPI_API_TOKEN" 48 | env: 49 | PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }} 50 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | test_scope: 7 | required: false 8 | type: string 9 | secrets: 10 | OPENAI_API_KEY: 11 | required: false 12 | COHERE_API_KEY: 13 | required: false 14 | CODECOV_TOKEN: 15 | required: false 16 | PINECONE_API_KEY: 17 | required: false 18 | 19 | jobs: 20 | test: 21 | runs-on: ubuntu-latest 22 | timeout-minutes: 20 # Fail the job if it runs longer than 20 minutes 23 | strategy: 24 | matrix: 25 | python-version: ["3.10", "3.13"] 26 | steps: 27 | - uses: actions/checkout@v4 28 | 29 | - name: Set up Python 30 | uses: actions/setup-python@v4 31 | with: 32 | python-version: ${{ matrix.python-version }} 33 | 
34 | - name: Install Dagger CLI 35 | run: | 36 | curl -L https://dl.dagger.io/dagger/install.sh | sh 37 | echo "$PWD/bin" >> $GITHUB_PATH 38 | 39 | - name: Run tests with Dagger 40 | run: | 41 | dagger call --mod ./.dagger test \ 42 | --src . \ 43 | --scope ${{ inputs.test_scope || 'all' }} \ 44 | --python-version "${{ matrix.python-version }}" \ 45 | --openai-api-key "$OPENAI_API_KEY" \ 46 | --cohere-api-key "$COHERE_API_KEY" \ 47 | --pinecone-api-key "$PINECONE_API_KEY" 48 | env: 49 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 50 | COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} 51 | PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }} 52 | POSTGRES_HOST: postgres 53 | POSTGRES_PORT: 5432 54 | POSTGRES_DB: postgres 55 | POSTGRES_USER: postgres 56 | POSTGRES_PASSWORD: postgres 57 | 58 | - name: Upload coverage to Codecov 59 | uses: codecov/codecov-action@v2 60 | with: 61 | file: ./coverage.xml 62 | token: ${{ secrets.CODECOV_TOKEN }} 63 | fail_ci_if_error: false 64 | -------------------------------------------------------------------------------- /.github/workflows/triggers_merge.yml: -------------------------------------------------------------------------------- 1 | name: Merge Workflow 2 | 3 | permissions: 4 | id-token: write 5 | contents: read 6 | 7 | on: 8 | push: 9 | branches: 10 | - main 11 | 12 | jobs: 13 | 14 | diff: 15 | name: "Diff" 16 | uses: ./.github/workflows/diff.yml 17 | with: 18 | config: diff-config.yaml 19 | 20 | docs: 21 | needs: [ diff ] 22 | name: "Docs" 23 | if: ${{ fromJson(needs.diff.outputs.tags).docs.changed }} 24 | uses: ./.github/workflows/docs.yml 25 | secrets: 26 | PAT: ${{ secrets.PAT }} 27 | -------------------------------------------------------------------------------- /.github/workflows/triggers_pr.yml: -------------------------------------------------------------------------------- 1 | name: PR Workflow 2 | 3 | permissions: 4 | id-token: write 5 | contents: read 6 | 7 | on: 8 | pull_request: 9 | branches: [ "**" ] 10 | types: [ opened, edited, synchronize ] 11 | workflow_call: 12 | inputs: 13 | test_scope: 14 | required: false 15 | type: string 16 | secrets: 17 | OPENAI_API_KEY: 18 | required: false 19 | COHERE_API_KEY: 20 | required: false 21 | CODECOV_TOKEN: 22 | required: false 23 | PINECONE_API_KEY: 24 | required: false 25 | 26 | jobs: 27 | 28 | diff: 29 | name: "Diff" 30 | uses: ./.github/workflows/diff.yml 31 | with: 32 | config: diff-config.yaml 33 | 34 | lint: 35 | needs: [ diff ] 36 | name: "Lint" 37 | if: ${{ fromJson(needs.diff.outputs.tags).code.changed }} 38 | uses: ./.github/workflows/lint.yml 39 | 40 | tests: 41 | needs: [ diff ] 42 | name: "Tests" 43 | if: ${{ fromJson(needs.diff.outputs.tags).code.changed }} 44 | uses: ./.github/workflows/test.yml 45 | with: 46 | test_scope: ${{ github.event.pull_request.head.repo.full_name != github.repository && 'unit-functional' || 'all' }} 47 | secrets: 48 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 49 | COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} 50 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 51 | PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }} 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | .venv 4 | .DS_Store 5 | venv/ 6 | /.vscode 7 | .vscode 8 | .idea 9 | .conda 10 | **/__pycache__ 11 | **/*.py[cod] 12 | 13 | 14 | # local env files 15 | .env*.local 16 | .env 17 | my.secrets 18 | mac.env 19 | local.mac.env 20 | 21 | # Code 
coverage history 22 | .coverage 23 | .coverage.* 24 | .pytest_cache 25 | coverage.xml 26 | test.py 27 | output 28 | node_modules 29 | package-lock.json 30 | package.json 31 | test.ipynb 32 | test_sync.ipynb 33 | ``` 34 | 35 | # docs 36 | docs/build 37 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.13 3 | repos: 4 | - repo: meta 5 | hooks: 6 | - id: check-hooks-apply 7 | - id: check-useless-excludes 8 | 9 | - repo: https://github.com/psf/black 10 | rev: 23.9.1 11 | hooks: 12 | - id: black 13 | 14 | - repo: https://github.com/asottile/blacken-docs 15 | rev: 1.16.0 16 | hooks: 17 | - id: blacken-docs 18 | additional_dependencies: [ black==22.10.0 ] 19 | 20 | - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook 21 | rev: v9.11.0 22 | hooks: 23 | - id: commitlint 24 | stages: [ commit-msg ] 25 | additional_dependencies: [ '@commitlint/config-conventional' ] 26 | 27 | - repo: https://github.com/codespell-project/codespell 28 | rev: v2.2.4 29 | hooks: 30 | - id: codespell 31 | name: Run codespell to check for common misspellings in files 32 | language: python 33 | types: [ text ] 34 | args: [ "--write-changes", "--ignore-words-list", "asend" ] 35 | exclude: "uv.lock" 36 | 37 | - repo: https://github.com/pre-commit/pre-commit-hooks 38 | rev: v4.4.0 39 | hooks: 40 | - id: check-vcs-permalinks 41 | - id: end-of-file-fixer 42 | - id: trailing-whitespace 43 | args: [ --markdown-linebreak-ext=md ] 44 | - id: debug-statements 45 | - id: no-commit-to-branch 46 | - id: check-merge-conflict 47 | - id: check-toml 48 | - id: check-yaml 49 | args: [ '--unsafe' ] # for mkdocs.yml 50 | - id: detect-private-key 51 | 52 | - repo: https://github.com/commitizen-tools/commitizen 53 | rev: v3.13.0 54 | hooks: 55 | - id: commitizen 56 | - id: commitizen-branch 57 | stages: 58 | - post-commit 59 | - push 60 | 61 | - repo: https://github.com/astral-sh/ruff-pre-commit 62 | rev: v0.0.290 63 | hooks: 64 | - id: ruff 65 | types_or: [ python, pyi, jupyter ] 66 | 67 | - repo: https://github.com/PyCQA/bandit 68 | rev: 1.7.6 69 | hooks: 70 | - id: bandit 71 | args: [ '-lll' ] 72 | -------------------------------------------------------------------------------- /.pydoc-markdown.yml: -------------------------------------------------------------------------------- 1 | loaders: 2 | - type: python 3 | packages: 4 | - semantic_router 5 | search_path: ["."] 6 | 7 | processors: 8 | - type: filter 9 | skip_empty_modules: true 10 | exclude_private: true 11 | exclude_special: true 12 | documented_only: true 13 | do_not_filter_modules: true 14 | - type: smart 15 | - type: crossref 16 | 17 | renderer: 18 | type: docusaurus 19 | docs_base_path: docs/build 20 | relative_output_path: . -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.13 2 | -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 
4 | 5 | ## Project Overview 6 | 7 | Semantic Router is a high-performance decision-making layer for LLMs and agents that uses semantic vector space to make routing decisions instead of waiting for slow LLM generations. It's maintained by Aurelio AI and provides both static and dynamic routing capabilities. 8 | 9 | ## Development Commands 10 | 11 | ### Package Management 12 | This project uses `uv` for dependency management. All commands should be prefixed with `uv run`. 13 | 14 | ### Testing 15 | ```bash 16 | # Run all tests with coverage 17 | make test 18 | 19 | # Run specific test categories 20 | make test_unit # Unit tests only 21 | make test_functional # Functional tests only 22 | make test_integration # Integration tests only 23 | 24 | # Run a single test file 25 | uv run pytest tests/unit/test_route.py -vv 26 | 27 | # Run a single test function 28 | uv run pytest tests/unit/test_route.py::test_function_name -vv 29 | ``` 30 | 31 | ### Code Quality 32 | ```bash 33 | # Run all linting and formatting 34 | make lint 35 | 36 | # Auto-fix and format code 37 | make format 38 | 39 | # Lint only files changed from main branch 40 | make lint_diff 41 | 42 | # Run type checking 43 | uv run mypy semantic_router/ 44 | ``` 45 | 46 | ### Running Examples 47 | Most examples are in Jupyter notebooks in the `docs/` directory. To run them: 48 | ```bash 49 | uv run jupyter notebook docs/00-introduction.ipynb 50 | ``` 51 | 52 | ## Architecture Overview 53 | 54 | ### Core Components 55 | 56 | 1. **Routes** (`semantic_router/route.py`) 57 | - Basic unit of routing logic 58 | - Contains name, utterances, optional function schemas and LLM 59 | - Supports both static (simple matching) and dynamic (parameter extraction) routes 60 | 61 | 2. **Routers** (`semantic_router/routers/`) 62 | - `BaseRouter`: Abstract base with core routing logic, training, and sync 63 | - `SemanticRouter`: Standard dense embedding-based router 64 | - `HybridRouter`: Combines dense and sparse embeddings for better accuracy 65 | 66 | 3. **Encoders** (`semantic_router/encoders/`) 67 | - `DenseEncoder`: Base class for semantic embeddings (OpenAI, Cohere, HuggingFace, etc.) 68 | - `SparseEncoder`: Base for keyword-based encodings (BM25, TF-IDF) 69 | - Supports both sync and async operations 70 | - Some encoders are asymmetric (different encoding for queries vs documents) 71 | 72 | 4. **Indexes** (`semantic_router/index/`) 73 | - `BaseIndex`: Abstract interface for vector storage 74 | - Implementations: `LocalIndex`, `PineconeIndex`, `PostgresIndex`, `QdrantIndex` 75 | - `HybridLocalIndex`: Supports both dense and sparse vectors 76 | 77 | ### Data Flow 78 | ``` 79 | User Query → Router → Encoder → Embeddings → Index → Similarity Search → Route Selection → Response 80 | ``` 81 | 82 | ### Key Patterns 83 | - **Strategy Pattern**: Swappable encoders and indexes 84 | - **Template Method**: BaseRouter defines algorithm, subclasses implement specifics 85 | - **Async Support**: Full async/await support throughout 86 | - **Configuration**: Routes can be imported/exported as JSON/YAML 87 | 88 | ## Important Considerations 89 | 90 | ### When Adding New Features 91 | 1. Check existing patterns in similar components (e.g., look at other encoders when adding a new encoder) 92 | 2. Ensure both sync and async versions are implemented where applicable 93 | 3. Add appropriate type hints and follow existing naming conventions 94 | 4. 
Add tests in the appropriate test directory (unit/functional/integration) 95 | 96 | ### Testing Guidelines 97 | - Unit tests go in `tests/unit/` and test individual components in isolation 98 | - Functional tests go in `tests/functional/` and test component interactions 99 | - Integration tests go in `tests/integration/` and test with real external services 100 | - Mock external API calls in unit tests 101 | - Use `pytest-mock` for mocking 102 | 103 | ### Common Gotchas 104 | 1. Many encoders require API keys set as environment variables (e.g., `OPENAI_API_KEY`) 105 | 2. The project supports Python 3.9-3.13; some features are disabled for 3.13+ due to dependency constraints 106 | 3. Local models require the `[local]` extra: `pip install "semantic-router[local]"` 107 | 4. Hybrid routing requires the `[hybrid]` extra 108 | 5. When working with indexes, be aware of synchronization between local and remote states 109 | 110 | ### Performance Considerations 111 | - Route encoding happens once during initialization 112 | - Query encoding happens on every request 113 | - Use `auto_sync="local"` for better performance when routes don't change frequently 114 | - HybridRouter is more accurate but slightly slower than SemanticRouter -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to the Semantic Router 2 | 3 | The Aurelio Team welcomes and encourages any contributions to the Semantic Router, big or small. Please feel free to contribute new features, bug fixes, or documentation. We're always eager to hear your suggestions. 4 | 5 | Please follow these guidelines when making a contribution: 6 | 1. **Check for Existing Issues:** Before making any changes, [check here for related issues](https://github.com/aurelio-labs/semantic-router/issues). 7 | 2. **Run Your Changes by Us!** If no related issue exists yet, please create one and suggest your changes. Checking in with the team first will allow us to determine if the changes are in scope. 8 | 3. **Set Up Development Environment** If the changes are agreed, then you can go ahead and set up a development environment (see [Setting Up Your Development Environment](#setting-up-your-development-environment) below). 9 | 4. **Create an Early Draft Pull Request** Once you have commits ready to be shared, initiate a draft Pull Request with an initial version of your implementation and request feedback. It's advisable not to wait until the feature is fully completed. 10 | 5. **Ensure that All Pull Request Checks Pass** There are Pull Request checks that need to be satisfied before the changes can be merged. These appear towards the bottom of the Pull Request webpage on GitHub, and include: 11 | - Ensure that the Pull Request title is prepended with a [valid type](https://flank.github.io/flank/pr_titles/). E.g. `feat: My New Feature`. 12 | - Run linting (and fix any issues that are flagged) by: 13 | - Navigating to /semantic-router. 14 | - Running `make lint` to fix linting issues. 15 | - Running `black .` to fix `black` linting issues. 16 | - Running `ruff check . --fix` to fix `ruff` linting issues (where possible, others may need manual changes). 17 | - Running `mypy .` and then fixing any of the issues that are raised. 18 | - Confirming the linters pass using `make lint` again. 19 | - Ensure that, for any new code, new [PyTests are written](https://github.com/aurelio-labs/semantic-router/tree/main/tests/unit).
If any code is removed, then ensure that corresponding PyTests are also removed. Finally, ensure that all remaining PyTests pass using `pytest ./tests` (to avoid integration tests you can run `pytest ./tests/unit`). 20 | - Codecov checks will inform you if any code is not covered by PyTests upon creating the PR. You should aim to cover new code with PyTests. 21 | 22 | > **Feedback and Discussion:** 23 | While we encourage you to initiate a draft Pull Request early to get feedback on your implementation, we also highly value discussions and questions. If you're unsure about any aspect of your contribution or need clarification on the project's direction, please don't hesitate to use the [Issues section](https://github.com/aurelio-labs/semantic-router/issues) of our repository. Engaging in discussions or asking questions before starting your work can help ensure that your efforts align well with the project's goals and existing work. 24 | 25 | # Setting Up Your Development Environment 26 | 27 | 1. Fork on GitHub: 28 | Go to the [repository's page](https://github.com/aurelio-labs/semantic-router) on GitHub: 29 | Click the "Fork" button in the top-right corner of the page. 30 | 31 | 2. Clone Your Fork: 32 | After forking, you'll be taken to your new fork of the repository on GitHub. Copy the URL of your fork from the address bar or by clicking the "Code" button and copying the URL under "Clone with HTTPS" or "Clone with SSH". 33 | Open your terminal or command prompt. 34 | Use the git clone command followed by the URL you copied to clone the repository to your local machine. Replace `https://github.com//.git` with the URL of your fork: 35 | ``` 36 | git clone https://github.com//.git 37 | ``` 38 | 39 | 3. Ensure you have [`uv` installed](https://docs.astral.sh/uv/getting-started/installation/); for macOS and Linux, use `curl -LsSf https://astral.sh/uv/install.sh | sh`. 40 | 41 | 42 | 4. Then navigate to the cloned folder, create a virtualenv, and install via `uv`: 43 | ``` 44 | # Move into the cloned folder 45 | cd semantic-router/ 46 | 47 | # Create a virtual environment 48 | uv venv --python 3.13 49 | 50 | # Activate the environment 51 | source .venv/bin/activate 52 | 53 | # Install via uv with all extras relevant to perform unit tests 54 | uv sync --extra all 55 | ``` 56 | 57 | ## Developing the CI Pipeline with Dagger 58 | 59 | We use [Dagger](https://dagger.io) for our CI pipeline. This allows us to fully reproduce everything that is run in GitHub Actions locally. To develop the CI pipeline, the following is recommended: 60 | 61 | 1. Install Dagger CLI for running CI/CD pipelines locally: 62 | - macOS: `brew install dagger/tap/dagger` 63 | - Linux: `curl -L https://dl.dagger.io/dagger/install.sh | sh` 64 | 65 | 2. Run unit test pipeline: 66 | 67 | ``` 68 | dagger call --mod ./.dagger unit-test --src . 69 | # `--mod ./.dagger` tells dagger to run from the ./.dagger directory 70 | # `--src .` runs the underlying pytest commands from the current directory (which should be the root SR folder) 71 | ``` 72 | 73 | 3.
_(Optional)_ If you are modifying the CI pipeline itself and see issues that _may_ be due to caching of the pipeline, you can clear the cache like so: 74 | 75 | ``` 76 | docker stop $(docker ps -a -q -f 'name=dagger-engine') 77 | docker rm $(docker ps -a -q -f 'name=dagger-engine') 78 | docker system prune -f -a --volumes 79 | ``` 80 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Aurelio AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | format: 2 | uv run ruff check . --fix 3 | 4 | PYTHON_FILES=. 5 | lint: PYTHON_FILES=. 6 | lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$') 7 | 8 | lint lint_diff: 9 | uv run ruff check . 10 | uv run ruff format . 
11 | uv run mypy $(PYTHON_FILES) 12 | 13 | test: 14 | uv run pytest -vv --cov=semantic_router --cov-report=term-missing --cov-report=xml --exitfirst --maxfail=1 15 | 16 | test_functional: 17 | uv run pytest -vv -s --exitfirst --maxfail=1 tests/functional 18 | test_unit: 19 | uv run pytest -vv --exitfirst --maxfail=1 tests/unit 20 | 21 | test_integration: 22 | uv run pytest -vv --exitfirst --maxfail=1 tests/integration 23 | -------------------------------------------------------------------------------- /commitlint.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | extends: ['@commitlint/config-conventional'], 3 | rules: { 4 | 'subject-case': [0, 'never'], 5 | 'subject-max-length': [2, 'always', 60], 6 | 'type-enum': [2, 'always', ['build', 'chore', 'ci', 'docs', 'feat', 'fix', 'perf', 'refactor', 'style', 'test']], 7 | }, 8 | }; 9 | -------------------------------------------------------------------------------- /diff-config.yaml: -------------------------------------------------------------------------------- 1 | modules: 2 | semantic_router: 3 | tags: [ code, docs ] 4 | pattern: semantic_router/** 5 | tests: 6 | tags: [ code ] 7 | pattern: tests/** 8 | docs: 9 | tags: [ docs ] 10 | pattern: docs/** 11 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/encoders/local-encoder.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Using LocalEncoder for Local Embeddings\n", 8 | "\n", 9 | "This notebook demonstrates how to use the `LocalEncoder` from `semantic-router` to generate embeddings locally using [sentence-transformers](https://www.sbert.net/).\n", 10 | "\n", 11 | "No API key is required. All computation happens on your machine (CPU, CUDA, or MPS)." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# Install dependencies if needed\n", 21 | "# !pip install -qU \"semantic-router[local]\"" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## Example Texts\n", 29 | "Let's define a few example texts to embed." 
30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 1, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "docs = [\n", 39 | " \"The quick brown fox jumps over the lazy dog.\",\n", 40 | " \"Artificial intelligence is transforming the world.\",\n", 41 | " \"Semantic search improves information retrieval.\",\n", 42 | " \"Local models run without internet access.\",\n", 43 | " \"Sentence Transformers provide high-quality embeddings.\",\n", 44 | "]" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "## Initialize LocalEncoder\n", 52 | "You can specify a model from [sentence-transformers](https://www.sbert.net/docs/pretrained_models.html). The default is `BAAI/bge-small-en-v1.5`." 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "from semantic_router.encoders.local import LocalEncoder\n", 62 | "\n", 63 | "encoder = LocalEncoder() # You can specify name='all-MiniLM-L6-v2', etc.\n", 64 | "print(f\"Using model: {encoder.name}\")\n", 65 | "print(f\"Device: {encoder.device}\")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## Encode the Texts\n", 73 | "Let's generate embeddings for our example texts." 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "embeddings = encoder(docs)\n", 83 | "print(f\"Generated {len(embeddings)} embeddings. Example:\")\n", 84 | "print(embeddings[0][:10]) # Show first 10 dimensions of the first embedding" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "## Notes\n", 92 | "- Embeddings are computed locally, so no data leaves your machine.\n", 93 | "- You can use any compatible sentence-transformers model by changing the `name` parameter.\n", 94 | "- The encoder will use CUDA or MPS if available, otherwise CPU.\n", 95 | "- Embeddings are normalized by default (L2 norm = 1).\n", 96 | "\n", 97 | "For more details, see the [sentence-transformers documentation](https://www.sbert.net/)." 98 | ] 99 | } 100 | ], 101 | "metadata": { 102 | "kernelspec": { 103 | "display_name": ".venv", 104 | "language": "python", 105 | "name": "python3" 106 | }, 107 | "language_info": { 108 | "codemirror_mode": { 109 | "name": "ipython", 110 | "version": 3 111 | }, 112 | "file_extension": ".py", 113 | "mimetype": "text/x-python", 114 | "name": "python", 115 | "nbconvert_exporter": "python", 116 | "pygments_lexer": "ipython3", 117 | "version": "3.13.2" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 2 122 | } 123 | -------------------------------------------------------------------------------- /docs/get-started/introduction.md: -------------------------------------------------------------------------------- 1 | Semantic Router is a superfast decision-making layer for LLMs and agents. Instead of waiting for slow LLM generations to make tool-use decisions, it uses semantic vector space to route requests based on meaning. 2 | 3 | ## What is Semantic Router? 
4 | 5 | Semantic Router enables: 6 | 7 | - **Faster decisions**: Make routing decisions in milliseconds rather than seconds 8 | - **Lower costs**: Avoid expensive LLM inference for simple routing tasks 9 | - **Better control**: Direct conversations, queries, and agent actions with precision 10 | - **Full flexibility**: Use cloud APIs or run everything locally 11 | 12 | ## Key Features 13 | 14 | - **Simple API**: Set up routes with just a few lines of code 15 | - **Dynamic routes**: Generate parameters and trigger function calls 16 | - **Multiple integrations**: Works with Cohere, OpenAI, Hugging Face, FastEmbed, and more 17 | - **Vector store support**: Integrates with Pinecone and Qdrant for persistence 18 | - **Multi-modal capabilities**: Route based on image content, not just text 19 | - **Local execution**: Run entirely on your machine with no API dependencies 20 | 21 | ## Version 0.1 Released 22 | 23 | Semantic Router v0.1 is now available! If you're migrating from an earlier version, please see our [migration guide](../user-guide/guides/migration-to-v1). 24 | 25 | ## Getting Started 26 | 27 | For a quick introduction to using Semantic Router, check out our [quickstart guide](quickstart). 28 | 29 | ## Execution Options 30 | 31 | Semantic Router supports multiple execution modes: 32 | 33 | - **Cloud-based**: Using OpenAI, Cohere, or other API-based embeddings 34 | - **Hybrid**: Combining local embeddings with API-based LLMs 35 | - **Fully local**: Run everything on your machine with models like Llama and Mistral 36 | 37 | ## Resources 38 | 39 | - [Documentation](https://docs.aurelio.ai/semantic-router/index.html) 40 | - [GitHub Repository](https://github.com/aurelio-labs/semantic-router) 41 | - [Online Course](https://www.aurelio.ai/course/semantic-router) -------------------------------------------------------------------------------- /docs/get-started/quickstart.md: -------------------------------------------------------------------------------- 1 | *Semantic-router* is a lightweight library that helps you intelligently route text to the right handlers based on meaning rather than exact keyword matching. It's perfect for building chatbots, classification systems, or any application that needs to understand user intent. 2 | 3 | To get started with *semantic-router*, we install it like so: 4 | 5 | ```bash 6 | pip install -qU semantic-router 7 | ``` 8 | 9 | > **Warning** 10 | > If you want to use a fully local version of Semantic Router, you can use `HuggingFaceEncoder` and `LlamaCppLLM` (`pip install -qU "semantic-router[local]"`, see [here](../user-guide/guides/local-execution)). To use the `HybridRouter` you must `pip install -qU "semantic-router[hybrid]"`. 11 | 12 | ## Defining Routes 13 | 14 | We begin by defining a set of `Route` objects. A Route represents a specific topic or intent that you want to detect in user input. Each Route is defined by example utterances that serve as a semantic reference point.
15 | 16 | Let's try two simple routes for now — one for talk about *politics* and another for *chitchat*: 17 | 18 | ```python 19 | from semantic_router import Route 20 | 21 | # we could use this as a guide for our chatbot to avoid political conversations 22 | politics = Route( 23 | name="politics", 24 | utterances=[ 25 | "isn't politics the best thing ever", 26 | "why don't you tell me about your political opinions", 27 | "don't you just love the president", 28 | "they're going to destroy this country!", 29 | "they will save the country!", 30 | ], 31 | ) 32 | 33 | # this could be used as an indicator to our chatbot to switch to a more 34 | # conversational prompt 35 | chitchat = Route( 36 | name="chitchat", 37 | utterances=[ 38 | "how's the weather today?", 39 | "how are things going?", 40 | "lovely weather today", 41 | "the weather is horrendous", 42 | "let's go to the chippy", 43 | ], 44 | ) 45 | 46 | # we place both of our routes together into a single list 47 | routes = [politics, chitchat] 48 | ``` 49 | 50 | ## Setting Up an Encoder 51 | 52 | With our routes ready, we now initialize an embedding/encoder model. The encoder converts text into numerical vectors, allowing the system to measure semantic similarity. We currently support `CohereEncoder` and `OpenAIEncoder` — more encoders will be added soon. 53 | 54 | To initialize them: 55 | 56 | ```python 57 | import os 58 | from semantic_router.encoders import CohereEncoder, OpenAIEncoder 59 | 60 | # for Cohere 61 | os.environ["COHERE_API_KEY"] = "" 62 | encoder = CohereEncoder() 63 | 64 | # or for OpenAI 65 | os.environ["OPENAI_API_KEY"] = "" 66 | encoder = OpenAIEncoder() 67 | ``` 68 | 69 | ## Creating a SemanticRouter 70 | 71 | With our `routes` and `encoder` defined, we now create a `SemanticRouter`. The SemanticRouter is the decision-making engine that compares incoming text against your routes to find the best semantic match. 72 | 73 | ```python 74 | from semantic_router.routers import SemanticRouter 75 | 76 | rl = SemanticRouter(encoder=encoder, routes=routes) 77 | ``` 78 | 79 | ## Making Routing Decisions 80 | 81 | We can now use our router to make super fast routing decisions based on user queries. Behind the scenes, the system converts both your example utterances and the incoming query into vectors and finds the closest match. 82 | 83 | Let's try with two queries that should trigger our route decisions: 84 | 85 | ```python 86 | rl("don't you love politics?").name 87 | ``` 88 | 89 | ``` 90 | [Out]: 'politics' 91 | ``` 92 | 93 | Correct decision! Let's try another: 94 | 95 | ```python 96 | rl("how's the weather today?").name 97 | ``` 98 | 99 | ``` 100 | [Out]: 'chitchat' 101 | ``` 102 | 103 | We get both decisions correct! The power of semantic routing is that it works even when queries don't exactly match your examples but are similar in meaning. 104 | 105 | ## Handling Unmatched Queries 106 | 107 | Now let's try sending an unrelated query: 108 | 109 | ```python 110 | rl("I'm interested in learning about llama 2").name 111 | ``` 112 | 113 | ``` 114 | [Out]: 115 | ``` 116 | 117 | In this case, no decision could be made as we had no semantic matches — so our router returned `None`! This feature is useful for creating fallback behavior or passthroughs in your applications when no intent is clearly matched.
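To make that fallback explicit, here is a minimal sketch that reuses the `rl` router defined above. The handler responses are hypothetical placeholders, not part of the library:

```python
def handle_query(query: str) -> str:
    route = rl(query)  # returns a RouteChoice; route.name is None when nothing matches
    if route.name == "politics":
        return "Let's steer clear of politics."  # hypothetical handler
    if route.name == "chitchat":
        return "Happy to chat! How are you doing?"  # hypothetical handler
    # no route matched above the similarity threshold, so fall back to a default path
    return "Passing the query through to the default pipeline."


handle_query("I'm interested in learning about llama 2")
```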
-------------------------------------------------------------------------------- /docs/indexes/postgres/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM postgres:latest 2 | 3 | RUN apt-get update && \ 4 | apt-get install -y build-essential postgresql-server-dev-all git && \ 5 | git clone https://github.com/pgvector/pgvector.git && \ 6 | cd pgvector && \ 7 | make && \ 8 | make install && \ 9 | cd .. && \ 10 | rm -rf pgvector && \ 11 | apt-get remove -y build-essential postgresql-server-dev-all git && \ 12 | apt-get autoremove -y && \ 13 | apt-get clean 14 | -------------------------------------------------------------------------------- /docs/indexes/postgres/postgres.compose.yaml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | pgvector: 5 | build: . 6 | environment: 7 | POSTGRES_DB: semantic_router 8 | POSTGRES_USER: admin 9 | POSTGRES_PASSWORD: root 10 | volumes: 11 | - db_data:/var/lib/postgresql/data 12 | ports: 13 | - "5432:5432" 14 | 15 | volumes: 16 | db_data: 17 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/user-guide/changelog.md: -------------------------------------------------------------------------------- 1 | ### v0.1.10 2 | 3 | The `0.1.10` release was primarily focused on expanding async support for `QdrantIndex`, `PostgresIndex`, and `HybridRouter`, alongside many synchronization and testing improvements. 4 | 5 | #### Feature: Expanded Async Support 6 | 7 | - **QdrantIndex**: Async methods have been brought inline with our other indexes, ensuring consistent behavior. 8 | - **PostgresIndex**: Async methods have been added to the `PostgresIndex` for improved performance in async environments. 9 | - **HybridRouter**: Async support for the `HybridRouter` is now aligned with the `SemanticRouter`, providing a more consistent experience. 10 | 11 | #### Fixes and Optimizations 12 | 13 | - **LocalIndex Bug Fix**: Added a `metadata` attribute to the local index. This fixes a bug where `LocalIndex` embeddings would always be recomputed, as reported in [issue #585](https://github.com/aurelio-labs/semantic-router/issues/585). 14 | - Various other bug fixes and optimizations have been included in this release. 15 | - The `urllib3` library has been upgraded. 16 | - Test compatibility and synchronization have been optimized. 
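For anyone trying the expanded async support described above, a rough sketch of the pattern looks like this (the route, encoder choice, and API key setup are illustrative assumptions; `acall` is the async counterpart of calling the router directly):

```python
import asyncio

from semantic_router import Route
from semantic_router.encoders import OpenAIEncoder
from semantic_router.routers import SemanticRouter

# assumes OPENAI_API_KEY is set in the environment
routes = [Route(name="chitchat", utterances=["how's the weather today?"])]
router = SemanticRouter(encoder=OpenAIEncoder(), routes=routes)


async def main():
    # async counterpart of router("...") for use inside event loops
    choice = await router.acall("lovely weather today")
    print(choice.name)


asyncio.run(main())
```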
17 | 18 | --- 19 | 20 | ### v0.1.9 21 | 22 | The `0.1.9` update focuses on improving support for our local deployment options. We have standardized the `PostgresIndex` to bring it in line with other index options and prepare it for future feature releases. For `local` extras (including `transformers` and `llama_cpp` support) and `postgres` extras we have resolved issues making those extras unusable with Python 3.13. 23 | 24 | Continue reading below for more detail. 25 | 26 | #### Feature: Improvements to `postgres` Support 27 | 28 | - **Upgrade from psycopg2 to v3 (psycopg)**: We've upgraded our PostgreSQL driver from psycopg2 to the newer psycopg v3, which provides better performance and modern Python compatibility. 29 | 30 | - **Standardization of sync methods**: The `PostgresIndex` class now has standardized synchronous methods that integrate seamlessly into our standard pytest test suite, ensuring better testing coverage and reliability. 31 | 32 | - **Addition of IndexType support**: We've introduced `IndexType` which includes `FLAT`, `IVFFLAT`, and `HNSW` index types. Both `FLAT` and `IVFFLAT` map to the `IVFFLAT` index in pgvector (pgvector does not currently support `FLAT` and so `IVFFLAT` is the closest approximation - but our other indexes do). `FLAT` is now the default index type. Due to its higher memory requirements and complexity, we recommend `HNSW` only for high-scale use cases with many millions of utterances. 33 | 34 | #### Feature: Configurable Logging 35 | 36 | - **Environment variable log level control**: Logger now supports configurable log levels via environment variables. Set `SEMANTIC_ROUTER_LOG_LEVEL` or `LOG_LEVEL` to control logging verbosity (e.g., `DEBUG`, `INFO`, `WARNING`, `ERROR`). Defaults to `INFO` if not specified. 37 | 38 | #### Fix: Local Installation for Python 3.13 and up 39 | 40 | Python 3.13 and up had originally been incompatible with our local installations due to the lack of compatibility with PyTorch and the (then) new version of Python. We have now added the following version support for `local` libraries: 41 | 42 | ```python 43 | local = [ 44 | "torch>=2.6.0 ; python_version >= '3.13'", 45 | "transformers>=4.50.0 ; python_version >= '3.13'", 46 | "tokenizer>=0.20.2 ; python_version >= '3.13'", 47 | "llama-cpp-python>=0.3.0 ; python_version >= '3.13'", 48 | ] 49 | ``` 50 | 51 | #### Fix: Consistent Index Length Behavior 52 | 53 | - **Standardized `__len__` method across all index implementations**: All index types now consistently return `0` when uninitialized, preventing potential `AttributeError` exceptions that could occur in `PineconeIndex` and `QdrantIndex` when checking the length of uninitialized indexes. 54 | 55 | #### Chore: Broader and More Useful Tests 56 | 57 | We have broken our tests apart into strict unit and integration test directories. Now, when incoming PRs are raised, we will no longer trigger integration tests that require API keys to successfully run. To ensure we're still covering all components of the library, we have broadened our testing suite to extensively test `LocalIndex`, `PineconeIndex` (via Pinecone Local), `PostgresIndex`, and `QdrantIndex` within those unit tests.
-------------------------------------------------------------------------------- /docs/user-guide/concepts/architecture.md: -------------------------------------------------------------------------------- 1 | # Semantic Router: Technical Architecture 2 | 3 | ## System Overview 4 | 5 | Semantic Router is built around three core components that work together to enable intelligent routing of inputs based on semantic meaning. 6 | 7 | ```mermaid 8 | graph TD 9 | A[Input Query] --> B[Encoder] 10 | B --> C[Vector Embedding] 11 | C --> D[Router] 12 | E[Routes] --> D 13 | F[Index] <--> D 14 | D --> G[Matched Route] 15 | G --> H[Response Handler] 16 | ``` 17 | 18 | ## Core Components 19 | 20 | ### 1. Encoders 21 | 22 | Encoders transform inputs into vector representations in semantic space. 23 | 24 | ```mermaid 25 | classDiagram 26 | class BaseEncoder { 27 | +encode(text: List[str]) -> Any 28 | +aencode(text: List[str]) -> Any 29 | } 30 | class DenseEncoder { 31 | +encode() -> dense vectors 32 | } 33 | class SparseEncoder { 34 | +encode() -> sparse vectors 35 | } 36 | BaseEncoder <|-- DenseEncoder 37 | BaseEncoder <|-- SparseEncoder 38 | DenseEncoder <|-- OpenAIEncoder 39 | DenseEncoder <|-- HuggingFaceEncoder 40 | DenseEncoder <|-- CLIPEncoder 41 | SparseEncoder <|-- AurelioSparseEncoder 42 | SparseEncoder <|-- BM25Encoder 43 | SparseEncoder <|-- TFIDFEncoder 44 | ``` 45 | 46 | **Types of Encoders:** 47 | - **Dense encoders**: Generate continuous vectors (OpenAI, HuggingFace, etc.) 48 | - **Sparse encoders**: Generate sparse vectors (BM25, TFIDF, AurelioSparse, etc.) 49 | - **Multimodal encoders**: Handle images and text (CLIP, ViT) 50 | 51 | ### 2. Routes 52 | 53 | Routes define patterns to match against, with examples of inputs that should trigger them. 54 | 55 | ```mermaid 56 | classDiagram 57 | class Route { 58 | +name: str 59 | +utterances: List[str] 60 | +description: Optional[str] 61 | +function_schemas: Optional[List[Dict]] 62 | +score_threshold: Optional[float] 63 | +metadata: Optional[Dict] 64 | } 65 | ``` 66 | 67 | **Key properties:** 68 | - **name**: Identifier for the route 69 | - **utterances**: Example inputs that should match this route 70 | - **function_schemas**: Optional specifications for function calling 71 | - **score_threshold**: Minimum similarity score required to match 72 | 73 | ### 3. Indexing Systems 74 | 75 | Indexes store and retrieve route vectors efficiently. 76 | 77 | ```mermaid 78 | classDiagram 79 | class BaseIndex { 80 | +add(embeddings, routes, utterances) 81 | +query(vector, top_k) -> matches 82 | +delete(route_name) 83 | } 84 | BaseIndex <|-- LocalIndex 85 | BaseIndex <|-- PostgresIndex 86 | BaseIndex <|-- PineconeIndex 87 | BaseIndex <|-- QdrantIndex 88 | LocalIndex <|-- HybridLocalIndex 89 | ``` 90 | 91 | **Index types:** 92 | - **LocalIndex**: In-memory vector storage for dense embeddings 93 | - **HybridLocalIndex**: In-memory storage supporting both dense and sparse vectors 94 | - **PineconeIndex/QdrantIndex**: Cloud-based vector DBs 95 | - **PostgresIndex**: SQL-based vector storage 96 | 97 | ## Data Flow 98 | 99 | ```mermaid 100 | sequenceDiagram 101 | participant User 102 | participant Router 103 | participant Encoder 104 | participant Index 105 | User->>Router: send query 106 | Router->>Encoder: encode query 107 | Encoder->>Router: return vector 108 | Router->>Index: search for similar routes 109 | Index->>Router: return matches 110 | Router->>User: return best matched route 111 | ``` 112 | 113 | 1. 
**Input Reception**: The system receives an input (text, image) 114 | 2. **Encoding**: The input is transformed into a vector representation 115 | 3. **Retrieval**: The vector is compared against stored route vectors 116 | 4. **Matching**: The best matching route is selected based on similarity 117 | 5. **Response**: The system returns the matched route, enabling appropriate handling 118 | 119 | ## Router Types 120 | 121 | ```mermaid 122 | classDiagram 123 | class BaseRouter { 124 | +__call__(query) -> RouteChoice 125 | +acall(query) -> RouteChoice 126 | +add(routes) 127 | +route(query) -> RouteChoice 128 | } 129 | BaseRouter <|-- SemanticRouter 130 | BaseRouter <|-- HybridRouter 131 | ``` 132 | 133 | - **SemanticRouter**: Uses dense vector embeddings for semantic matching 134 | - **HybridRouter**: Combines both dense and sparse vectors for enhanced accuracy 135 | 136 | ## Integration Example 137 | 138 | ```python 139 | from semantic_router import Route, SemanticRouter 140 | from semantic_router.encoders import OpenAIEncoder 141 | 142 | # 1. Define routes 143 | weather_route = Route(name="weather", utterances=["What's the weather like?"]) 144 | greeting_route = Route(name="greeting", utterances=["Hello there!", "Hi!"]) 145 | 146 | # 2. Initialize encoder 147 | encoder = OpenAIEncoder() 148 | 149 | # 3. Create router with routes 150 | router = SemanticRouter(encoder=encoder, routes=[weather_route, greeting_route]) 151 | 152 | # 4. Route an incoming query 153 | result = router("What's the forecast for tomorrow?") 154 | print(result.name) # "weather" 155 | ``` 156 | 157 | ## Performance Considerations 158 | 159 | - **In-memory vs. Vector DB**: Choose based on scale and latency requirements 160 | - **Encoder selection**: Balance accuracy vs. speed based on use case 161 | - **Batch processing**: Use batch methods for higher throughput 162 | - **Async support**: Available for high-concurrency environments and applications relying 163 | on heavy network use 164 | -------------------------------------------------------------------------------- /docs/user-guide/features/route-filter.md: -------------------------------------------------------------------------------- 1 | We can filter the routes that the `SemanticRouter` considers when making a classification. This can be useful if we want to restrict the scope of possible routes based on some context. 2 | 3 | For example, we may have a router with several routes, `politics`, `weather`, `chitchat`, etc. We may want to restrict the scope of the classification to only consider the `chitchat` route. We can do this by passing a `route_filter` argument to our `SemanticRouter` calls like so: 4 | 5 | ```python 6 | sr("don't you love politics?", route_filter=["chitchat"]) 7 | ``` 8 | 9 | In this case, the `SemanticRouter` will only consider the `chitchat` route for the classification. 10 | 11 | ## Full Example 12 | 13 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/09-route-filter.ipynb) 14 | [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/00-introduction.ipynb) 15 | 16 | We start by installing the library: 17 | 18 | ```python 19 | !pip install -qU semantic-router 20 | ``` 21 | 22 | We start by defining a dictionary mapping routes to example phrases that should trigger those routes. 
23 | 24 | ```python 25 | from semantic_router import Route 26 | 27 | politics = Route( 28 | name="politics", 29 | utterances=[ 30 | "isn't politics the best thing ever", 31 | "why don't you tell me about your political opinions", 32 | "don't you just love the president", 33 | "don't you just hate the president", 34 | "they're going to destroy this country!", 35 | "they will save the country!", 36 | ], 37 | ) 38 | ``` 39 | 40 | Let's define another for good measure: 41 | 42 | ```python 43 | chitchat = Route( 44 | name="chitchat", 45 | utterances=[ 46 | "how's the weather today?", 47 | "how are things going?", 48 | "lovely weather today", 49 | "the weather is horrendous", 50 | "let's go to the chippy", 51 | ], 52 | ) 53 | 54 | routes = [politics, chitchat] 55 | ``` 56 | 57 | Now we initialize our embedding model: 58 | 59 | ```python 60 | import os 61 | from getpass import getpass 62 | from semantic_router.encoders import CohereEncoder, OpenAIEncoder 63 | 64 | os.environ["COHERE_API_KEY"] = os.getenv("COHERE_API_KEY") or getpass( 65 | "Enter Cohere API Key: " 66 | ) 67 | # os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY") or getpass( 68 | # "Enter OpenAI API Key: " 69 | # ) 70 | 71 | encoder = CohereEncoder() 72 | # encoder = OpenAIEncoder() 73 | ``` 74 | 75 | Now we define the `SemanticRouter`. When called, the router will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `SemanticRouter` we need our `encoder` model and a list of `routes`. 76 | 77 | ```python 78 | from semantic_router.routers import SemanticRouter 79 | 80 | sr = SemanticRouter(encoder=encoder, routes=routes) 81 | ``` 82 | 83 | Now we can test it: 84 | 85 | ```python 86 | sr("don't you love politics?") 87 | ``` 88 | 89 | ``` 90 | RouteChoice(name='politics', function_call=None, similarity_score=None) 91 | ``` 92 | 93 | ```python 94 | sr("how's the weather today?") 95 | ``` 96 | 97 | ``` 98 | RouteChoice(name='chitchat', function_call=None, similarity_score=None) 99 | ``` 100 | 101 | Both are classified accurately, what if we send a query that is unrelated to our existing `Route` objects? 102 | 103 | ```python 104 | sr("I'm interested in learning about llama 2") 105 | ``` 106 | 107 | ``` 108 | RouteChoice(name=None, function_call=None, similarity_score=None) 109 | ``` 110 | 111 | In this case, we return `None` because no matches were identified. 112 | 113 | ## Demonstrating the Filter Feature 114 | 115 | Now, let's demonstrate the filter feature. We can specify a subset of routes to consider when making a classification. This can be useful if we want to restrict the scope of possible routes based on some context. 116 | 117 | For example, let's say we only want to consider the "chitchat" route for a particular query: 118 | 119 | ```python 120 | sr("don't you love politics?", route_filter=["chitchat"]) 121 | ``` 122 | 123 | ``` 124 | RouteChoice(name='chitchat', function_call=None, similarity_score=None) 125 | ``` 126 | 127 | Even though the query might be more related to the "politics" route, it will be classified as "chitchat" because we've restricted the routes to consider. 128 | 129 | Similarly, we can restrict it to the "politics" route: 130 | 131 | ```python 132 | sr("how's the weather today?", route_filter=["politics"]) 133 | ``` 134 | 135 | ``` 136 | RouteChoice(name=None, function_call=None, similarity_score=None) 137 | ``` 138 | 139 | In this case, it will return `None` because the query doesn't match the "politics" route well enough to pass the threshold. 
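140 | 
141 | The `route_filter` argument takes a list of route names, so we can also restrict classification to any subset of routes rather than a single one. As a minimal sketch (reusing the `sr` router defined above, and assuming the allowed route names come from application state such as a hypothetical per-user configuration), we might write:
142 | 
143 | ```python
144 | # hypothetical list of routes this user is allowed to trigger
145 | allowed_routes = ["politics", "chitchat"]
146 | 
147 | result = sr("don't you love politics?", route_filter=allowed_routes)
148 | print(result.name)  # "politics", since both routes are considered
149 | ```
150 | 
151 | Passing the full list of route names is equivalent to calling the router with no filter at all, so filters are only worth constructing when they genuinely narrow the candidate routes.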
-------------------------------------------------------------------------------- /docs/user-guide/guides/configuration.md: -------------------------------------------------------------------------------- 1 | # Configuration 2 | 3 | This guide covers various configuration options available in semantic-router. 4 | 5 | ## Logging Configuration 6 | 7 | Semantic-router uses Python's logging module for debugging and monitoring. You can control the verbosity of logs using environment variables. 8 | 9 | ### Setting Log Levels 10 | 11 | You can configure the log level in two ways: 12 | 13 | 1. **Using the semantic-router specific variable (recommended):** 14 | ```bash 15 | export SEMANTIC_ROUTER_LOG_LEVEL=DEBUG 16 | ``` 17 | 18 | 2. **Using the general LOG_LEVEL variable:** 19 | ```bash 20 | export LOG_LEVEL=WARNING 21 | ``` 22 | 23 | The library checks for `SEMANTIC_ROUTER_LOG_LEVEL` first, then falls back to `LOG_LEVEL`. If neither is set, it defaults to `INFO`. 24 | 25 | ### Available Log Levels 26 | 27 | - `DEBUG`: Detailed information for diagnosing problems 28 | - `INFO`: General informational messages (default) 29 | - `WARNING`: Warning messages for potentially problematic situations 30 | - `ERROR`: Error messages for serious problems 31 | - `CRITICAL`: Critical messages for very serious errors 32 | 33 | ### Example Usage 34 | 35 | ```python 36 | import os 37 | # Set before importing semantic-router 38 | os.environ["SEMANTIC_ROUTER_LOG_LEVEL"] = "DEBUG" 39 | 40 | from semantic_router import Route, SemanticRouter 41 | # Your debug logs will now be visible 42 | ``` 43 | 44 | This is particularly useful when: 45 | - Debugging encoder or index issues 46 | - Monitoring route matching decisions 47 | - Troubleshooting performance problems 48 | - Understanding the library's internal behavior -------------------------------------------------------------------------------- /docs/user-guide/guides/save-load-from-file.md: -------------------------------------------------------------------------------- 1 | Route layers can be saved to and loaded from files. This can be useful if we want to save a route layer to a file for later use, or if we want to load a route layer from a file. 2 | 3 | We can save and load route layers to/from YAML or JSON files. For JSON we do: 4 | 5 | ```python 6 | # save to JSON 7 | router.to_json("router.json") 8 | # load from JSON 9 | new_router = SemanticRouter.from_json("router.json") 10 | ``` 11 | 12 | For YAML we do: 13 | 14 | ```python 15 | # save to YAML 16 | router.to_yaml("router.yaml") 17 | # load from YAML 18 | new_router = SemanticRouter.from_yaml("router.yaml") 19 | ``` 20 | 21 | The saved files contain all the information needed to initialize new semantic routers. If you are using a remote index, you can use the [sync features](../features/sync) to keep the router in sync with the index. 22 | 23 | ## Full Example 24 | 25 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/01-save-load-from-file.ipynb) 26 | [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/01-save-load-from-file.ipynb) 27 | 28 | Here we will show how to save routers to YAML or JSON files, and how to load a router from file. 
29 | 30 | We start by installing the library: 31 | 32 | ```bash 33 | !pip install -qU semantic-router 34 | ``` 35 | 36 | ## Define Route 37 | 38 | First let's create a list of routes: 39 | 40 | ```python 41 | from semantic_router import Route 42 | 43 | politics = Route( 44 | name="politics", 45 | utterances=[ 46 | "isn't politics the best thing ever", 47 | "why don't you tell me about your political opinions", 48 | "don't you just love the president", 49 | "don't you just hate the president", 50 | "they're going to destroy this country!", 51 | "they will save the country!", 52 | ], 53 | ) 54 | chitchat = Route( 55 | name="chitchat", 56 | utterances=[ 57 | "how's the weather today?", 58 | "how are things going?", 59 | "lovely weather today", 60 | "the weather is horrendous", 61 | "let's go to the chippy", 62 | ], 63 | ) 64 | 65 | routes = [politics, chitchat] 66 | ``` 67 | 68 | We define a semantic router using these routes and using the Cohere encoder. 69 | 70 | ```python 71 | import os 72 | from getpass import getpass 73 | from semantic_router import SemanticRouter 74 | from semantic_router.encoders import CohereEncoder 75 | 76 | # dashboard.cohere.ai 77 | os.environ["COHERE_API_KEY"] = os.getenv("COHERE_API_KEY") or getpass( 78 | "Enter Cohere API Key: " 79 | ) 80 | 81 | encoder = CohereEncoder() 82 | 83 | router = SemanticRouter(encoder=encoder, routes=routes, auto_sync="local") 84 | ``` 85 | 86 | ## Test Route 87 | 88 | ```python 89 | router("isn't politics the best thing ever") 90 | ``` 91 | 92 | Output: 93 | ``` 94 | RouteChoice(name='politics', function_call=None, similarity_score=None) 95 | ``` 96 | 97 | ```python 98 | router("how's the weather today?") 99 | ``` 100 | 101 | Output: 102 | ``` 103 | RouteChoice(name='chitchat', function_call=None, similarity_score=None) 104 | ``` 105 | 106 | ## Save To JSON 107 | 108 | To save our semantic router we call the `to_json` method: 109 | 110 | ```python 111 | router.to_json("router.json") 112 | ``` 113 | 114 | ## Loading from JSON 115 | 116 | We can view the router file we just saved to see what information is stored. 117 | 118 | ```python 119 | import json 120 | 121 | with open("router.json", "r") as f: 122 | router_json = json.load(f) 123 | 124 | print(router_json) 125 | ``` 126 | 127 | It tells us our encoder type, encoder name, and routes. This is everything we need to initialize a new router. To do so, we use the `from_json` method. 128 | 129 | ```python 130 | router = SemanticRouter.from_json("router.json") 131 | ``` 132 | 133 | We can confirm that our router has been initialized with the expected attributes by viewing the `SemanticRouter` object: 134 | 135 | ```python 136 | print( 137 | f"""{router.encoder.type=} 138 | {router.encoder.name=} 139 | {router.routes=}""" 140 | ) 141 | ``` 142 | 143 | --- 144 | 145 | ## Test Route Again 146 | 147 | ```python 148 | router("isn't politics the best thing ever") 149 | ``` 150 | 151 | Output: 152 | ``` 153 | RouteChoice(name='politics', function_call=None, similarity_score=None) 154 | ``` 155 | 156 | ```python 157 | router("how's the weather today?") 158 | ``` 159 | 160 | Output: 161 | ``` 162 | RouteChoice(name='chitchat', function_call=None, similarity_score=None) 163 | ``` -------------------------------------------------------------------------------- /docs/user-guide/guides/semantic-router.md: -------------------------------------------------------------------------------- 1 | The `SemanticRouter` is the main class of the semantic router. 
It is responsible 2 | for making decisions about which route to take based on an input utterance. 3 | A `SemanticRouter` consists of an `encoder`, an `index`, and a list of `routes`. 4 | Route layers that include dynamic routes (i.e. routes that can generate dynamic 5 | decision outputs) also include an `llm`. 6 | 7 | To use a `SemanticRouter` we first need some `routes`. We can initialize them like 8 | so: 9 | 10 | ```python 11 | from semantic_router import Route 12 | 13 | politics = Route( 14 | name="politics", 15 | utterances=[ 16 | "isn't politics the best thing ever", 17 | "why don't you tell me about your political opinions", 18 | "don't you just love the president", 19 | "don't you just hate the president", 20 | "they're going to destroy this country!", 21 | "they will save the country!", 22 | ], 23 | ) 24 | 25 | chitchat = Route( 26 | name="chitchat", 27 | utterances=[ 28 | "how's the weather today?", 29 | "how are things going?", 30 | "lovely weather today", 31 | "the weather is horrendous", 32 | "let's go to the chippy", 33 | ], 34 | ) 35 | ``` 36 | 37 | We initialize an encoder — there are many options available here, from local 38 | to API-based. For now we'll use the `OpenAIEncoder`. 39 | 40 | ```python 41 | import os 42 | from semantic_router.encoders import OpenAIEncoder 43 | 44 | os.environ["OPENAI_API_KEY"] = "" 45 | 46 | encoder = OpenAIEncoder() 47 | ``` 48 | 49 | Now we define the `RouteLayer`. When called, the route layer will consume text 50 | (a query) and output the category (`Route`) it belongs to — to initialize a 51 | `RouteLayer` we need our `encoder` model and a list of `routes`. 52 | 53 | ```python 54 | from semantic_router import SemanticRouter 55 | 56 | sr = SemanticRouter(encoder=encoder, routes=routes, auto_sync="local") 57 | ``` 58 | 59 | Now we can call the `RouteLayer` with an input query: 60 | 61 | ```python 62 | sr("don't you love politics?") 63 | ``` 64 | 65 | ``` 66 | [Out]: RouteChoice(name='politics', function_call=None, similarity_score=None) 67 | ``` 68 | 69 | The output is a `RouteChoice` object, which contains the name of the route, 70 | the function call (if any), and the similarity score that triggered the route 71 | choice. 72 | 73 | We can try another query: 74 | 75 | ```python 76 | sr("how's the weather today?") 77 | ``` 78 | 79 | ``` 80 | [Out]: RouteChoice(name='chitchat', function_call=None, similarity_score=None) 81 | ``` 82 | 83 | Both are classified accurately, what if we send a query that is unrelated to 84 | our existing Route objects? 85 | 86 | ```python 87 | sr("I'm interested in learning about llama 3") 88 | ``` 89 | 90 | ``` 91 | [Out]: RouteChoice(name=None, function_call=None, similarity_score=None) 92 | ``` 93 | 94 | In this case, the `RouteLayer` is unable to find a route that matches the 95 | input query and so returns a `RouteChoice` with `name=None`. 96 | 97 | We can also retrieve multiple routes with their associated score using 98 | `retrieve_multiple_routes`: 99 | 100 | ```python 101 | sr.retrieve_multiple_routes("Hi! 
How are you doing in politics??") 102 | ``` 103 | 104 | ``` 105 | [Out]: [RouteChoice(name='politics', function_call=None, similarity_score=0.859), 106 | RouteChoice(name='chitchat', function_call=None, similarity_score=0.835)] 107 | ``` 108 | 109 | If `retrieve_multiple_routes` is called with a query that does not match any 110 | routes, it will return an empty list: 111 | 112 | ```python 113 | sr.retrieve_multiple_routes("I'm interested in learning about llama 3") 114 | ``` 115 | 116 | ``` 117 | [Out]: [] 118 | ``` 119 | 120 | You can find an introductory notebook for the [route layer here](https://github.com/aurelio-labs/semantic-router/blob/main/docs/00-introduction.ipynb). -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "semantic-router" 3 | version = "0.1.11" 4 | description = "Super fast semantic router for AI decision making" 5 | authors = [{ name = "Aurelio AI", email = "hello@aurelio.ai" }] 6 | requires-python = ">=3.9,<3.14" 7 | readme = "README.md" 8 | license = "MIT" 9 | dependencies = [ 10 | "pydantic>=2.10.2,<3", 11 | "numpy>=1.25.2", 12 | "colorlog>=6.8.0,<7", 13 | "pyyaml>=6.0.1,<7", 14 | "aurelio-sdk>=0.0.19", 15 | "colorama>=0.4.6,<0.5", 16 | "regex>=2023.12.25", 17 | "tiktoken>=0.6.0,<1.0.0", 18 | "aiohttp>=3.10.11,<4", 19 | "tornado>=6.4.2,<7", 20 | "urllib3>=1.26,<3", 21 | "litellm>=1.61.3", 22 | "openai>=1.10.0,<2.0.0" 23 | ] 24 | 25 | [project.optional-dependencies] 26 | ollama = ["ollama>=0.1.7"] 27 | local = [ 28 | "transformers>=4.36.2 ; python_version < '3.13'", 29 | "tokenizers>=0.19 ; python_version < '3.13'", 30 | "llama-cpp-python>=0.2.28,<0.2.86 ; python_version < '3.13'", 31 | "sentence-transformers>=5.0.0 ; python_version < '3.13'", 32 | "torch>=2.6.0 ; python_version < '3.13'" 33 | ] 34 | vision = [ 35 | "torchvision>=0.17.0 ; python_version < '3.13'", 36 | "transformers>=4.36.2 ; python_version < '3.13'", 37 | "pillow>=10.2.0,<11.0.0 ; python_version < '3.13'", 38 | "torch>=2.6.0 ; python_version < '3.13'" 39 | ] 40 | pinecone = ["pinecone>=5.0.0,<6.0.0"] 41 | mistralai = ["mistralai>=0.0.12,<0.1.0"] 42 | qdrant = ["qdrant-client>=1.11.1,<2"] 43 | google = ["google-cloud-aiplatform>=1.45.0,<2"] 44 | bedrock = [ 45 | "boto3>=1.34.98,<2", 46 | "botocore>=1.34.110,<2", 47 | ] 48 | postgres = [ 49 | "psycopg[binary]>=3.1.0,<4", 50 | ] 51 | fastembed = ["fastembed>=0.3.0,<0.4 ; python_version < '3.13'"] 52 | docs = ["pydoc-markdown>=4.8.2 ; python_version < '3.12'"] 53 | cohere = ["cohere>=5.9.4,<6.00"] 54 | dev = [ 55 | "ipykernel>=6.25.0,<7", 56 | "ruff>=0.11.2,<0.12", 57 | "pytest~=8.2", 58 | "pytest-mock>=3.12.0,<4", 59 | "pytest-cov>=4.1.0,<5", 60 | "pytest-xdist>=3.5.0,<4", 61 | "pytest-asyncio>=0.24.0,<0.25", 62 | "pytest-timeout", 63 | "mypy>=1.7.1,<2", 64 | "types-pyyaml>=6.0.12.12,<7", 65 | "requests-mock>=1.12.1,<2", 66 | "types-requests>=2.31.0,<3", 67 | "dagger-io>=0.1.1 ; python_version >= '3.11'" 68 | ] 69 | all = [ 70 | "semantic-router[local]", 71 | "semantic-router[pinecone]", 72 | "semantic-router[vision]", 73 | "semantic-router[mistralai]", 74 | "semantic-router[qdrant]", 75 | "semantic-router[google]", 76 | "semantic-router[bedrock]", 77 | "semantic-router[postgres]", 78 | "semantic-router[fastembed]", 79 | "semantic-router[cohere]", 80 | "semantic-router[dev]", 81 | "semantic-router[ollama]", 82 | "pytest-timeout", 83 | ] 84 | 85 | [tool.hatch.build.targets.sdist] 86 | include = 
["semantic_router"] 87 | 88 | [tool.hatch.build.targets.wheel] 89 | include = ["semantic_router"] 90 | 91 | [build-system] 92 | requires = ["hatchling"] 93 | build-backend = "hatchling.build" 94 | 95 | [tool.ruff.lint.per-file-ignores] 96 | "*.ipynb" = ["I", "E501", "T201", "F404"] 97 | 98 | [tool.ruff] 99 | line-length = 88 100 | 101 | [tool.ruff.lint] 102 | select = ["E", "F", "I", "T201", "NPY201"] 103 | ignore = ["E501"] 104 | 105 | [tool.mypy] 106 | ignore_missing_imports = true 107 | 108 | # PyTorch indexes 109 | [[tool.uv.index]] 110 | name = "pytorch-cpu" 111 | url = "https://download.pytorch.org/whl/cpu" 112 | explicit = true 113 | 114 | [[tool.uv.index]] 115 | name = "pytorch-nightly" 116 | url = "https://download.pytorch.org/whl/nightly/cpu" 117 | explicit = true 118 | 119 | [tool.uv.sources] 120 | torch = [ 121 | { index = "pytorch-nightly", marker = "python_version >= '3.13' and sys_platform == 'darwin'" }, 122 | { index = "pytorch-cpu", marker = "python_version < '3.13' or sys_platform != 'darwin'" }, 123 | ] 124 | torchvision = [ 125 | { index = "pytorch-nightly", marker = "python_version >= '3.13' and sys_platform == 'darwin'" }, 126 | { index = "pytorch-cpu", marker = "python_version < '3.13' or sys_platform != 'darwin'" }, 127 | ] 128 | -------------------------------------------------------------------------------- /replace.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | 5 | def replace_type_hints(file_path): 6 | with open(file_path, "rb") as file: 7 | file_data = file.read() 8 | 9 | # Decode the file data with error handling 10 | file_data = file_data.decode("utf-8", errors="ignore") 11 | 12 | # Regular expression pattern to find 'Dict[Type1, Type2] | None' and replace with 'Optional[Dict[Type1, Type2]]'. 
13 | file_data = re.sub( 14 | r"Dict\[(\w+), (\w+)\]\s*\|\s*None", r"Optional[Dict[\1, \2]]", file_data 15 | ) 16 | 17 | with open(file_path, "w") as file: 18 | file.write(file_data) 19 | 20 | 21 | # Directory path 22 | dir_path = "/Users/jakit/customers/aurelio/semantic-router" 23 | 24 | # Traverse the directory 25 | for root, dirs, files in os.walk(dir_path): 26 | for file in files: 27 | if file.endswith(".py"): 28 | replace_type_hints(os.path.join(root, file)) 29 | -------------------------------------------------------------------------------- /semantic_router/__init__.py: -------------------------------------------------------------------------------- 1 | from semantic_router.route import Route 2 | from semantic_router.routers import HybridRouter, RouterConfig, SemanticRouter 3 | 4 | __all__ = ["SemanticRouter", "HybridRouter", "Route", "RouterConfig"] 5 | 6 | __version__ = "0.1.2" 7 | -------------------------------------------------------------------------------- /semantic_router/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from semantic_router.encoders.base import DenseEncoder, SparseEncoder # isort: skip 4 | from semantic_router.encoders.aurelio import AurelioSparseEncoder 5 | from semantic_router.encoders.azure_openai import AzureOpenAIEncoder 6 | from semantic_router.encoders.bedrock import BedrockEncoder 7 | from semantic_router.encoders.bm25 import BM25Encoder 8 | from semantic_router.encoders.clip import CLIPEncoder 9 | from semantic_router.encoders.cohere import CohereEncoder 10 | from semantic_router.encoders.fastembed import FastEmbedEncoder 11 | from semantic_router.encoders.google import GoogleEncoder 12 | from semantic_router.encoders.huggingface import HFEndpointEncoder, HuggingFaceEncoder 13 | from semantic_router.encoders.jina import JinaEncoder 14 | from semantic_router.encoders.litellm import LiteLLMEncoder 15 | from semantic_router.encoders.local import LocalEncoder, LocalSparseEncoder 16 | from semantic_router.encoders.mistral import MistralEncoder 17 | from semantic_router.encoders.nvidia_nim import NimEncoder 18 | from semantic_router.encoders.ollama import OllamaEncoder 19 | from semantic_router.encoders.openai import OpenAIEncoder 20 | from semantic_router.encoders.tfidf import TfidfEncoder 21 | from semantic_router.encoders.vit import VitEncoder 22 | from semantic_router.encoders.voyage import VoyageEncoder 23 | from semantic_router.schema import EncoderType, SparseEmbedding 24 | 25 | __all__ = [ 26 | "AurelioSparseEncoder", 27 | "DenseEncoder", 28 | "SparseEncoder", 29 | "AzureOpenAIEncoder", 30 | "CohereEncoder", 31 | "OpenAIEncoder", 32 | "BM25Encoder", 33 | "TfidfEncoder", 34 | "FastEmbedEncoder", 35 | "HuggingFaceEncoder", 36 | "HFEndpointEncoder", 37 | "MistralEncoder", 38 | "VitEncoder", 39 | "CLIPEncoder", 40 | "GoogleEncoder", 41 | "BedrockEncoder", 42 | "LiteLLMEncoder", 43 | "VoyageEncoder", 44 | "JinaEncoder", 45 | "NimEncoder", 46 | "OllamaEncoder", 47 | "LocalEncoder", 48 | "LocalSparseEncoder", 49 | ] 50 | 51 | 52 | class AutoEncoder: 53 | type: EncoderType 54 | name: Optional[str] 55 | model: DenseEncoder | SparseEncoder 56 | 57 | def __init__(self, type: str, name: Optional[str]): 58 | self.type = EncoderType(type) 59 | self.name = name 60 | if self.type == EncoderType.AZURE: 61 | self.model = AzureOpenAIEncoder(name=name) 62 | elif self.type == EncoderType.COHERE: 63 | self.model = CohereEncoder(name=name) 64 | elif self.type == 
EncoderType.OPENAI: 65 | self.model = OpenAIEncoder(name=name) 66 | elif self.type == EncoderType.AURELIO: 67 | self.model = AurelioSparseEncoder(name=name) 68 | elif self.type == EncoderType.BM25: 69 | if name is None: 70 | name = "bm25" 71 | self.model = BM25Encoder(name=name) 72 | elif self.type == EncoderType.TFIDF: 73 | if name is None: 74 | name = "tfidf" 75 | self.model = TfidfEncoder(name=name) 76 | elif self.type == EncoderType.FASTEMBED: 77 | self.model = FastEmbedEncoder(name=name) 78 | elif self.type == EncoderType.HUGGINGFACE: 79 | self.model = HuggingFaceEncoder(name=name) 80 | elif self.type == EncoderType.MISTRAL: 81 | self.model = MistralEncoder(name=name) 82 | elif self.type == EncoderType.VOYAGE: 83 | self.model = VoyageEncoder(name=name) 84 | elif self.type == EncoderType.JINA: 85 | self.model = JinaEncoder(name=name) 86 | elif self.type == EncoderType.NIM: 87 | self.model = NimEncoder(name=name) 88 | elif self.type == EncoderType.VIT: 89 | self.model = VitEncoder(name=name) 90 | elif self.type == EncoderType.CLIP: 91 | self.model = CLIPEncoder(name=name) 92 | elif self.type == EncoderType.GOOGLE: 93 | self.model = GoogleEncoder(name=name) 94 | elif self.type == EncoderType.BEDROCK: 95 | self.model = BedrockEncoder(name=name) # type: ignore 96 | elif self.type == EncoderType.LITELLM: 97 | self.model = LiteLLMEncoder(name=name) 98 | elif self.type == EncoderType.OLLAMA: 99 | self.model = OllamaEncoder(name=name) 100 | elif self.type == EncoderType.LOCAL: 101 | self.model = LocalEncoder(name=name) 102 | else: 103 | raise ValueError(f"Encoder type '{type}' not supported") 104 | 105 | def __call__(self, texts: List[str]) -> List[List[float]] | List[SparseEmbedding]: 106 | return self.model(texts) 107 | -------------------------------------------------------------------------------- /semantic_router/encoders/aurelio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, List, Optional 3 | 4 | from aurelio_sdk import AsyncAurelioClient, AurelioClient, EmbeddingResponse 5 | from pydantic import Field 6 | 7 | from semantic_router.encoders.base import AsymmetricSparseMixin, SparseEncoder 8 | from semantic_router.schema import SparseEmbedding 9 | 10 | 11 | class AurelioSparseEncoder(SparseEncoder, AsymmetricSparseMixin): 12 | """Sparse encoder using Aurelio Platform's embedding API. Requires an API key from 13 | https://platform.aurelio.ai 14 | """ 15 | 16 | model: Optional[Any] = None 17 | client: AurelioClient = Field(default_factory=AurelioClient, exclude=True) 18 | async_client: AsyncAurelioClient = Field( 19 | default_factory=AsyncAurelioClient, exclude=True 20 | ) 21 | type: str = "sparse" 22 | 23 | def __init__( 24 | self, 25 | name: str | None = None, 26 | api_key: Optional[str] = None, 27 | ): 28 | """Initialize the AurelioSparseEncoder. 29 | 30 | :param name: The name of the model to use. 31 | :type name: str | None 32 | :param api_key: The API key to use. 33 | :type api_key: str | None 34 | """ 35 | if name is None: 36 | name = "bm25" 37 | super().__init__(name=name) 38 | if api_key is None: 39 | api_key = os.getenv("AURELIO_API_KEY") 40 | if api_key is None: 41 | raise ValueError("AURELIO_API_KEY environment variable is not set.") 42 | self.client = AurelioClient(api_key=api_key) 43 | self.async_client = AsyncAurelioClient(api_key=api_key) 44 | 45 | def __call__(self, docs: list[str]) -> list[SparseEmbedding]: 46 | """Encode a list of queries using the Aurelio Platform embedding API. 
Documents 47 | must be strings, sparse encoders do not support other types. 48 | """ 49 | return self.encode_queries(docs) 50 | 51 | def encode_queries(self, docs: List[str]) -> List[SparseEmbedding]: 52 | res: EmbeddingResponse = self.client.embedding( 53 | input=docs, model=self.name, input_type="queries" 54 | ) 55 | embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] 56 | return embeds 57 | 58 | def encode_documents(self, docs: List[str]) -> List[SparseEmbedding]: 59 | res: EmbeddingResponse = self.client.embedding( 60 | input=docs, model=self.name, input_type="documents" 61 | ) 62 | embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] 63 | return embeds 64 | 65 | async def aencode_queries(self, docs: List[str]) -> list[SparseEmbedding]: 66 | res: EmbeddingResponse = await self.async_client.embedding( 67 | input=docs, model=self.name, input_type="queries" 68 | ) 69 | embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] 70 | return embeds 71 | 72 | async def aencode_documents(self, docs: List[str]) -> list[SparseEmbedding]: 73 | res: EmbeddingResponse = await self.async_client.embedding( 74 | input=docs, model=self.name, input_type="documents" 75 | ) 76 | embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] 77 | return embeds 78 | 79 | async def acall(self, docs: list[str]) -> list[SparseEmbedding]: 80 | """Asynchronously encode a list of documents using the Aurelio Platform 81 | embedding API. Documents must be strings, sparse encoders do not support other 82 | types. 83 | 84 | :param docs: The documents to encode. 85 | :type docs: list[str] 86 | :param input_type: 87 | :type semantic_router.encoders.encode_input_type.EncodeInputType 88 | :return: The encoded documents. 89 | :rtype: list[SparseEmbedding] 90 | """ 91 | return await self.aencode_queries(docs) 92 | 93 | def fit(self, docs: List[str]): 94 | """Fit the encoder to a list of documents. AurelioSparseEncoder does not support 95 | fit yet. 96 | 97 | :param docs: The documents to fit the encoder to. 98 | :type docs: list[str] 99 | """ 100 | raise NotImplementedError("AurelioSparseEncoder does not support fit.") 101 | -------------------------------------------------------------------------------- /semantic_router/encoders/cohere.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | 4 | import litellm 5 | from pydantic import PrivateAttr 6 | from typing_extensions import deprecated 7 | 8 | from semantic_router.encoders.litellm import LiteLLMEncoder, litellm_to_list 9 | from semantic_router.utils.defaults import EncoderDefault 10 | 11 | 12 | class CohereEncoder(LiteLLMEncoder): 13 | """Dense encoder that uses Cohere API to embed documents. Supports text only. Requires 14 | a Cohere API key from https://dashboard.cohere.com/api-keys. 15 | """ 16 | 17 | _client: Any = PrivateAttr() # TODO: deprecated, to remove in v0.2.0 18 | _async_client: Any = PrivateAttr() # TODO: deprecated, to remove in v0.2.0 19 | _embed_type: Any = PrivateAttr() # TODO: deprecated, to remove in v0.2.0 20 | type: str = "cohere" 21 | 22 | def __init__( 23 | self, 24 | name: str | None = None, 25 | cohere_api_key: str | None = None, # TODO: rename to api_key in v0.2.0 26 | score_threshold: float = 0.3, 27 | ): 28 | """Initialize the Cohere encoder. 29 | 30 | :param name: The name of the embedding model to use such as "embed-english-v3.0" or 31 | "embed-multilingual-v3.0". 
32 | :type name: str 33 | :param cohere_api_key: The API key for the Cohere client, can also 34 | be set via the COHERE_API_KEY environment variable. 35 | :type cohere_api_key: str 36 | :param score_threshold: The threshold for the score of the embedding. 37 | :type score_threshold: float 38 | """ 39 | # get default model name if none provided and convert to litellm format 40 | if name is None: 41 | name = f"cohere/{EncoderDefault.COHERE.value['embedding_model']}" 42 | elif not name.startswith("cohere/"): 43 | name = f"cohere/{name}" 44 | super().__init__( 45 | name=name, 46 | score_threshold=score_threshold, 47 | api_key=cohere_api_key, 48 | ) 49 | self._client = None # TODO: deprecated, to remove in v0.2.0 50 | self._async_client = None # TODO: deprecated, to remove in v0.2.0 51 | 52 | # TODO: deprecated, to remove in v0.2.0 53 | @deprecated("_initialize_client method no longer required") 54 | def _initialize_client(self, cohere_api_key: str | None = None): 55 | """Initializes the Cohere client. 56 | 57 | :param cohere_api_key: The API key for the Cohere client, can also 58 | be set via the COHERE_API_KEY environment variable. 59 | :type cohere_api_key: str 60 | :return: An instance of the Cohere client. 61 | :rtype: cohere.Client 62 | """ 63 | cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") 64 | if cohere_api_key is None: 65 | raise ValueError("Cohere API key cannot be 'None'.") 66 | return None, None 67 | 68 | def encode_queries(self, docs: list[str], **kwargs) -> list[list[float]]: 69 | try: 70 | embeds = litellm.embedding( 71 | input=docs, 72 | input_type="search_query", 73 | model=f"{self.type}/{self.name}", 74 | **kwargs, 75 | ) 76 | return litellm_to_list(embeds) 77 | except Exception as e: 78 | raise ValueError(f"Cohere API call failed. Error: {e}") from e 79 | 80 | def encode_documents(self, docs: list[str], **kwargs) -> list[list[float]]: 81 | try: 82 | embeds = litellm.embedding( 83 | input=docs, 84 | input_type="search_document", 85 | model=f"{self.type}/{self.name}", 86 | **kwargs, 87 | ) 88 | return litellm_to_list(embeds) 89 | except Exception as e: 90 | raise ValueError(f"Cohere API call failed. Error: {e}") from e 91 | 92 | async def aencode_queries(self, docs: list[str], **kwargs) -> list[list[float]]: 93 | try: 94 | embeds = await litellm.aembedding( 95 | input=docs, 96 | input_type="search_query", 97 | model=f"{self.type}/{self.name}", 98 | **kwargs, 99 | ) 100 | return litellm_to_list(embeds) 101 | except Exception as e: 102 | raise ValueError(f"Cohere API call failed. Error: {e}") from e 103 | 104 | async def aencode_documents(self, docs: list[str], **kwargs) -> list[list[float]]: 105 | try: 106 | embeds = await litellm.aembedding( 107 | input=docs, 108 | input_type="search_document", 109 | model=f"{self.type}/{self.name}", 110 | **kwargs, 111 | ) 112 | return litellm_to_list(embeds) 113 | except Exception as e: 114 | raise ValueError(f"Cohere API call failed. 
Error: {e}") from e 115 | -------------------------------------------------------------------------------- /semantic_router/encoders/encode_input_type.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | EncodeInputType = Literal["queries", "documents"] 4 | -------------------------------------------------------------------------------- /semantic_router/encoders/fastembed.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | import numpy as np 4 | from pydantic import PrivateAttr 5 | 6 | from semantic_router.encoders import DenseEncoder 7 | 8 | 9 | class FastEmbedEncoder(DenseEncoder): 10 | """Dense encoder that uses local FastEmbed to embed documents. Supports text only. 11 | Requires the fastembed package which can be installed with `pip install 'semantic-router[fastembed]'` 12 | 13 | :param name: The name of the embedding model to use. 14 | :param max_length: The maximum length of the input text. 15 | :param cache_dir: The directory to cache the embedding model. 16 | :param threads: The number of threads to use for the embedding. 17 | """ 18 | 19 | type: str = "fastembed" 20 | name: str = "BAAI/bge-small-en-v1.5" 21 | max_length: int = 512 22 | cache_dir: Optional[str] = None 23 | threads: Optional[int] = None 24 | _client: Any = PrivateAttr() 25 | 26 | def __init__(self, score_threshold: float = 0.5, **data): 27 | """Initialize the FastEmbed encoder. 28 | 29 | :param score_threshold: The threshold for the score of the embedding. 30 | :type score_threshold: float 31 | """ 32 | # TODO default score_threshold not thoroughly tested, should optimize 33 | super().__init__(score_threshold=score_threshold, **data) 34 | self._client = self._initialize_client() 35 | 36 | def _initialize_client(self): 37 | """Initialize the FastEmbed library. Requires the fastembed package.""" 38 | try: 39 | from fastembed import TextEmbedding 40 | except ImportError: 41 | raise ImportError( 42 | "Please install fastembed to use FastEmbedEncoder. " 43 | "You can install it with: " 44 | "`pip install 'semantic-router[fastembed]'`" 45 | ) 46 | 47 | embedding_args = { 48 | "model_name": self.name, 49 | "max_length": self.max_length, 50 | "cache_dir": self.cache_dir, 51 | "threads": self.threads, 52 | } 53 | 54 | embedding_args = {k: v for k, v in embedding_args.items() if v is not None} 55 | 56 | embedding = TextEmbedding(**embedding_args) 57 | return embedding 58 | 59 | def __call__(self, docs: List[str]) -> List[List[float]]: 60 | """Embed a list of documents. Supports text only. 61 | 62 | :param docs: The documents to embed. 63 | :type docs: List[str] 64 | :raise ValueError: If the embedding fails. 65 | :return: The vector embeddings of the documents. 66 | :rtype: List[List[float]] 67 | """ 68 | try: 69 | embeds: List[np.ndarray] = list(self._client.embed(docs)) 70 | embeddings: List[List[float]] = [e.tolist() for e in embeds] 71 | return embeddings 72 | except Exception as e: 73 | raise ValueError(f"FastEmbed embed failed. 
Error: {e}") from e 74 | -------------------------------------------------------------------------------- /semantic_router/encoders/jina.py: -------------------------------------------------------------------------------- 1 | """This file contains the JinaEncoder class which is used to encode text using Jina""" 2 | 3 | from semantic_router.encoders.litellm import LiteLLMEncoder 4 | from semantic_router.utils.defaults import EncoderDefault 5 | 6 | 7 | class JinaEncoder(LiteLLMEncoder): 8 | """Class to encode text using Jina. Requires a Jina API key from 9 | https://jina.ai/api-keys/""" 10 | 11 | type: str = "jina" 12 | 13 | def __init__( 14 | self, 15 | name: str | None = None, 16 | api_key: str | None = None, 17 | score_threshold: float = 0.4, 18 | ): 19 | """Initialize the JinaEncoder. 20 | 21 | :param name: The name of the embedding model to use such as "jina-embeddings-v3". 22 | :param jina_api_key: The Jina API key. 23 | :type jina_api_key: str 24 | """ 25 | 26 | if name is None: 27 | name = f"jina_ai/{EncoderDefault.JINA.value['embedding_model']}" 28 | elif not name.startswith("jina_ai/"): 29 | name = f"jina_ai/{name}" 30 | super().__init__( 31 | name=name, 32 | score_threshold=score_threshold, 33 | api_key=api_key, 34 | ) 35 | -------------------------------------------------------------------------------- /semantic_router/encoders/litellm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | 4 | import litellm 5 | 6 | from semantic_router.encoders import DenseEncoder 7 | from semantic_router.encoders.base import AsymmetricDenseMixin 8 | from semantic_router.utils.defaults import EncoderDefault 9 | 10 | 11 | def litellm_to_list(embeds: litellm.EmbeddingResponse) -> list[list[float]]: 12 | """Convert a LiteLLM embedding response to a list of embeddings. 13 | 14 | :param embeds: The LiteLLM embedding response. 15 | :return: A list of embeddings. 16 | """ 17 | if ( 18 | not embeds 19 | or not isinstance(embeds, litellm.EmbeddingResponse) 20 | or not embeds.data 21 | ): 22 | raise ValueError("No embeddings found in LiteLLM embedding response.") 23 | return [x["embedding"] for x in embeds.data] 24 | 25 | 26 | class LiteLLMEncoder(DenseEncoder, AsymmetricDenseMixin): 27 | """LiteLLM encoder class for generating embeddings using LiteLLM. 28 | 29 | The LiteLLMEncoder class is a subclass of DenseEncoder and utilizes the LiteLLM SDK 30 | to generate embeddings for given documents. It supports all encoders supported by LiteLLM 31 | and supports customization of the score threshold for filtering or processing the embeddings. 32 | """ 33 | 34 | type: str = "litellm" 35 | 36 | def __init__( 37 | self, 38 | name: str | None = None, 39 | score_threshold: float | None = None, 40 | api_key: str | None = None, 41 | ): 42 | """Initialize the LiteLLMEncoder. 43 | 44 | :param name: The name of the embedding model to use. Must use LiteLLM naming 45 | convention (e.g. "openai/text-embedding-3-small" or "mistral/mistral-embed"). 46 | :type name: str 47 | :param score_threshold: The score threshold for the embeddings. 
48 | :type score_threshold: float 49 | """ 50 | if name is None: 51 | # defaults to default openai model if none provided 52 | name = "openai/" + EncoderDefault.OPENAI.value["embedding_model"] 53 | super().__init__( 54 | name=name, 55 | score_threshold=score_threshold if score_threshold is not None else 0.3, 56 | ) 57 | self.type, self.name = self.name.split("/", 1) 58 | if api_key is None: 59 | api_key = os.getenv(self.type.upper() + "_API_KEY") 60 | if api_key is None: 61 | raise ValueError( 62 | "Expected API key via `api_key` parameter or `{self.type.upper()}_API_KEY` " 63 | "environment variable." 64 | ) 65 | os.environ[self.type.upper() + "_API_KEY"] = api_key 66 | 67 | def __call__(self, docs: list[Any], **kwargs) -> list[list[float]]: 68 | """Encode a list of text documents into embeddings using LiteLLM. 69 | 70 | :param docs: List of text documents to encode. 71 | :return: List of embeddings for each document.""" 72 | return self.encode_queries(docs, **kwargs) 73 | 74 | async def acall(self, docs: list[Any], **kwargs) -> list[list[float]]: 75 | """Encode a list of documents into embeddings using LiteLLM asynchronously. 76 | 77 | :param docs: List of documents to encode. 78 | :return: List of embeddings for each document.""" 79 | return await self.aencode_queries(docs, **kwargs) 80 | 81 | def encode_queries(self, docs: list[str], **kwargs) -> list[list[float]]: 82 | try: 83 | embeds = litellm.embedding( 84 | input=docs, model=f"{self.type}/{self.name}", **kwargs 85 | ) 86 | return litellm_to_list(embeds) 87 | except Exception as e: 88 | raise ValueError( 89 | f"{self.type.capitalize()} API call failed. Error: {e}" 90 | ) from e 91 | 92 | def encode_documents(self, docs: list[str], **kwargs) -> list[list[float]]: 93 | try: 94 | embeds = litellm.embedding( 95 | input=docs, model=f"{self.type}/{self.name}", **kwargs 96 | ) 97 | return litellm_to_list(embeds) 98 | except Exception as e: 99 | raise ValueError( 100 | f"{self.type.capitalize()} API call failed. Error: {e}" 101 | ) from e 102 | 103 | async def aencode_queries(self, docs: list[str], **kwargs) -> list[list[float]]: 104 | try: 105 | embeds = await litellm.aembedding( 106 | input=docs, model=f"{self.type}/{self.name}", **kwargs 107 | ) 108 | return litellm_to_list(embeds) 109 | except Exception as e: 110 | raise ValueError( 111 | f"{self.type.capitalize()} API call failed. Error: {e}" 112 | ) from e 113 | 114 | async def aencode_documents(self, docs: list[str], **kwargs) -> list[list[float]]: 115 | try: 116 | embeds = await litellm.aembedding( 117 | input=docs, model=f"{self.type}/{self.name}", **kwargs 118 | ) 119 | return litellm_to_list(embeds) 120 | except Exception as e: 121 | raise ValueError( 122 | f"{self.type.capitalize()} API call failed. 
Error: {e}" 123 | ) from e 124 | -------------------------------------------------------------------------------- /semantic_router/encoders/local.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | from pydantic import PrivateAttr 4 | 5 | from semantic_router.encoders.base import DenseEncoder, SparseEncoder 6 | from semantic_router.schema import SparseEmbedding 7 | 8 | 9 | class LocalEncoder(DenseEncoder): 10 | """Local encoder using sentence-transformers for efficient local embeddings.""" 11 | 12 | name: str = "BAAI/bge-small-en-v1.5" 13 | type: str = "local" 14 | device: Optional[str] = None 15 | normalize_embeddings: bool = True 16 | batch_size: int = 32 17 | _model: Any = PrivateAttr() 18 | 19 | def __init__(self, **kwargs): 20 | super().__init__(**kwargs) 21 | try: 22 | from sentence_transformers import SentenceTransformer 23 | except ImportError: 24 | raise ImportError( 25 | "Please install sentence-transformers to use LocalEncoder. " 26 | "You can install it with: `pip install semantic-router[local]`" 27 | ) 28 | self._model = SentenceTransformer(self.name) 29 | if self.device: 30 | self._model.to(self.device) 31 | else: 32 | # Auto-detect device 33 | import torch 34 | 35 | if torch.cuda.is_available(): 36 | self.device = "cuda" 37 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): 38 | self.device = "mps" 39 | else: 40 | self.device = "cpu" 41 | self._model.to(self.device) 42 | 43 | def __call__(self, docs: List[str]) -> List[List[float]]: 44 | result = self._model.encode( 45 | docs, 46 | batch_size=self.batch_size, 47 | normalize_embeddings=self.normalize_embeddings, 48 | device=self.device, 49 | ) 50 | return result.tolist() # type: ignore[attr-defined] 51 | 52 | 53 | class LocalSparseEncoder(SparseEncoder): 54 | """Local sparse encoder using sentence-transformers' SparseEncoder (e.g., SPLADE, CSR) for efficient local sparse embeddings.""" 55 | 56 | name: str = "naver/splade-v3" 57 | type: str = "sparse_local" 58 | device: Optional[str] = None 59 | batch_size: int = 32 60 | _model: Any = PrivateAttr() 61 | 62 | def __init__(self, **kwargs): 63 | super().__init__(**kwargs) 64 | try: 65 | from sentence_transformers import SparseEncoder as STSparseEncoder 66 | except ImportError: 67 | raise ImportError( 68 | "Please install sentence-transformers >=v5 to use SparseSentenceTransformerEncoder. 
" 69 | "You can install it with: `pip install sentence-transformers`" 70 | ) 71 | self._model = STSparseEncoder(self.name) 72 | if self.device: 73 | self._model.to(self.device) 74 | else: 75 | # Auto-detect device 76 | import torch 77 | 78 | if torch.cuda.is_available(): 79 | self.device = "cuda" 80 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): 81 | self.device = "mps" 82 | else: 83 | self.device = "cpu" 84 | self._model.to(self.device) 85 | 86 | def __call__(self, docs: List[str]) -> List[SparseEmbedding]: 87 | # The model.encode returns a numpy array (batch, vocab_size) sparse matrix 88 | sparse_embeddings = self._model.encode(docs, batch_size=self.batch_size) 89 | # Convert to List[SparseEmbedding] using the helper from base.py 90 | return self._array_to_sparse_embeddings(sparse_embeddings) 91 | -------------------------------------------------------------------------------- /semantic_router/encoders/mistral.py: -------------------------------------------------------------------------------- 1 | """This file contains the MistralEncoder class which is used to encode text using MistralAI""" 2 | 3 | import os 4 | from typing import Any 5 | 6 | from pydantic import PrivateAttr 7 | from typing_extensions import deprecated 8 | 9 | from semantic_router.encoders.litellm import LiteLLMEncoder 10 | from semantic_router.utils.defaults import EncoderDefault 11 | 12 | 13 | class MistralEncoder(LiteLLMEncoder): 14 | """Class to encode text using MistralAI. Requires a MistralAI API key from 15 | https://console.mistral.ai/api-keys/""" 16 | 17 | _client: Any = PrivateAttr() # TODO: deprecated, to remove in v0.2.0 18 | _mistralai: Any = PrivateAttr() # TODO: deprecated, to remove in v0.2.0 19 | type: str = "mistral" 20 | 21 | def __init__( 22 | self, 23 | name: str | None = None, 24 | mistralai_api_key: str | None = None, # TODO: rename to api_key in v0.2.0 25 | score_threshold: float = 0.4, 26 | ): 27 | """Initialize the MistralEncoder. 28 | 29 | :param name: The name of the embedding model to use such as "mistral-embed". 30 | :type name: str 31 | :param mistralai_api_key: The MistralAI API key. 32 | :type mistralai_api_key: str 33 | :param score_threshold: The score threshold for the embeddings. 34 | """ 35 | # get default model name if none provided and convert to litellm format 36 | if name is None: 37 | name = f"mistral/{EncoderDefault.MISTRAL.value['embedding_model']}" 38 | elif not name.startswith("mistral/"): 39 | name = f"mistral/{name}" 40 | if mistralai_api_key is None: 41 | mistralai_api_key = os.getenv("MISTRALAI_API_KEY") 42 | if mistralai_api_key is None: 43 | mistralai_api_key = os.getenv("MISTRAL_API_KEY") 44 | super().__init__( 45 | name=name, 46 | score_threshold=score_threshold, 47 | api_key=mistralai_api_key, 48 | ) 49 | 50 | # TODO: deprecated, to remove in v0.2.0 51 | @deprecated("_initialize_client method no longer required") 52 | def _initialize_client(self, api_key): 53 | """Initialize the MistralAI client. 54 | 55 | :param api_key: The MistralAI API key. 
56 | :type api_key: str 57 | :return: None 58 | :rtype: None 59 | """ 60 | api_key = ( 61 | api_key or os.getenv("MISTRALAI_API_KEY") or os.getenv("MISTRAL_API_KEY") 62 | ) 63 | if api_key is None: 64 | raise ValueError("Mistral API key not provided") 65 | return None 66 | -------------------------------------------------------------------------------- /semantic_router/encoders/nvidia_nim.py: -------------------------------------------------------------------------------- 1 | """This file contains the NimEncoder class which is used to encode text using Nim""" 2 | 3 | import litellm 4 | 5 | from semantic_router.encoders.litellm import LiteLLMEncoder, litellm_to_list 6 | from semantic_router.utils.defaults import EncoderDefault 7 | 8 | 9 | class NimEncoder(LiteLLMEncoder): 10 | """Class to encode text using Nvidia NIM. Requires a Nim API key from 11 | https://nim.ai/api-keys/""" 12 | 13 | type: str = "nvidia_nim" 14 | 15 | def __init__( 16 | self, 17 | name: str | None = None, 18 | api_key: str | None = None, 19 | score_threshold: float = 0.4, 20 | ): 21 | """Initialize the NimEncoder. 22 | 23 | :param name: The name of the embedding model to use such as "nv-embedqa-e5-v5". 24 | :type name: str 25 | :param nim_api_key: The Nim API key. 26 | :type nim_api_key: str 27 | """ 28 | 29 | if name is None: 30 | name = f"nvidia_nim/{EncoderDefault.NVIDIA_NIM.value['embedding_model']}" 31 | elif not name.startswith("nvidia_nim/"): 32 | name = f"nvidia_nim/{name}" 33 | super().__init__( 34 | name=name, 35 | score_threshold=score_threshold, 36 | api_key=api_key, 37 | ) 38 | 39 | def encode_queries(self, docs: list[str], **kwargs) -> list[list[float]]: 40 | try: 41 | embeds = litellm.embedding( 42 | input=docs, 43 | input_type="passage", 44 | model=f"{self.type}/{self.name}", 45 | **kwargs, 46 | ) 47 | return litellm_to_list(embeds) 48 | except Exception as e: 49 | raise ValueError(f"Nim API call failed. Error: {e}") from e 50 | 51 | def encode_documents(self, docs: list[str], **kwargs) -> list[list[float]]: 52 | try: 53 | embeds = litellm.embedding( 54 | input=docs, 55 | input_type="passage", 56 | model=f"{self.type}/{self.name}", 57 | **kwargs, 58 | ) 59 | return litellm_to_list(embeds) 60 | except Exception as e: 61 | raise ValueError(f"Nim API call failed. Error: {e}") from e 62 | 63 | async def aencode_queries(self, docs: list[str], **kwargs) -> list[list[float]]: 64 | try: 65 | embeds = await litellm.aembedding( 66 | input=docs, 67 | input_type="passage", 68 | model=f"{self.type}/{self.name}", 69 | **kwargs, 70 | ) 71 | return litellm_to_list(embeds) 72 | except Exception as e: 73 | raise ValueError(f"Nim API call failed. Error: {e}") from e 74 | 75 | async def aencode_documents(self, docs: list[str], **kwargs) -> list[list[float]]: 76 | try: 77 | embeds = await litellm.aembedding( 78 | input=docs, 79 | input_type="passage", 80 | model=f"{self.type}/{self.name}", 81 | **kwargs, 82 | ) 83 | return litellm_to_list(embeds) 84 | except Exception as e: 85 | raise ValueError(f"Nim API call failed. Error: {e}") from e 86 | -------------------------------------------------------------------------------- /semantic_router/encoders/ollama.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, List, Optional 3 | 4 | from semantic_router.encoders import DenseEncoder 5 | from semantic_router.utils.defaults import EncoderDefault 6 | 7 | 8 | class OllamaEncoder(DenseEncoder): 9 | """OllamaEncoder class for generating embeddings using OLLAMA. 
10 | 11 | https://ollama.com/search?c=embedding 12 | 13 | Example usage: 14 | 15 | ```python 16 | from semantic_router.encoders.ollama import OllamaEncoder 17 | 18 | encoder = OllamaEncoder(base_url="http://localhost:11434") 19 | embeddings = encoder(["document1", "document2"]) 20 | ``` 21 | 22 | Attributes: 23 | client: An instance of the TextEmbeddingModel client. 24 | type: The type of the encoder, which is "ollama". 25 | """ 26 | 27 | client: Optional[Any] = None 28 | type: str = "ollama" 29 | 30 | def __init__( 31 | self, 32 | name: Optional[str] = None, 33 | score_threshold: float = 0.5, 34 | base_url: str | None = None, 35 | ): 36 | """Initializes the OllamaEncoder. 37 | 38 | :param model_name: The name of the pre-trained model to use for embedding. 39 | If not provided, the default model specified in EncoderDefault will 40 | be used. 41 | :type model_name: str 42 | :param score_threshold: The threshold for similarity scores. 43 | :type score_threshold: float 44 | :param base_url: The API endpoint for OLLAMA. 45 | If not provided, it will be retrieved from the `OLLAMA_BASE_URL` environment variable. 46 | :type base_url: str 47 | 48 | :raise ValueError: If the hosted base url is not provided properly or if the ollama 49 | client fails to initialize. 50 | """ 51 | if name is None: 52 | name = EncoderDefault.OLLAMA.value["embedding_model"] 53 | 54 | super().__init__(name=name, score_threshold=score_threshold) 55 | if base_url is None: 56 | base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") 57 | self.client = self._initialize_client(base_url=base_url) 58 | 59 | def _initialize_client(self, base_url: str): 60 | """Initializes the Google AI Platform client. 61 | 62 | :param base_url: hosted URL of ollama. 63 | :return: An instance of the TextEmbeddingModel client. 64 | :rtype: TextEmbeddingModel 65 | :raise ImportError: If the required ollama library is not installed. 66 | :raise ValueError: If the hosted base url is not provided properly or if the ollama 67 | client fails to initialize. 68 | """ 69 | try: 70 | from ollama import Client 71 | except ImportError: 72 | raise ImportError( 73 | "The 'ollama' package is not installed. Install it with: pip install 'semantic-router[ollama]'" 74 | ) 75 | 76 | client: Client = Client(host=base_url) 77 | return client 78 | 79 | def __call__(self, docs: List[str]) -> List[List[float]]: 80 | """Generates embeddings for the given documents. 81 | 82 | :param docs: A list of strings representing the documents to embed. 83 | :type docs: List[str] 84 | :return: A list of lists, where each inner list contains the embedding values for a 85 | document. 86 | :rtype: List[List[float]] 87 | :raise ValueError: If the Google AI Platform client is not initialized or if the 88 | API call fails. 89 | """ 90 | if self.client is None: 91 | raise ValueError("OLLAMA Platform client is not initialized.") 92 | try: 93 | return self.client.embed(model=self.name, input=docs).embeddings 94 | except Exception as e: 95 | raise ValueError(f"OLLAMA API call failed. 
Error: {e}") from e 96 | -------------------------------------------------------------------------------- /semantic_router/encoders/tfidf.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import string 3 | from collections import Counter 4 | from typing import Dict, List 5 | 6 | import numpy as np 7 | 8 | from semantic_router.encoders import SparseEncoder 9 | from semantic_router.encoders.base import FittableMixin 10 | from semantic_router.route import Route 11 | from semantic_router.schema import SparseEmbedding 12 | 13 | 14 | class TfidfEncoder(SparseEncoder, FittableMixin): 15 | idf: np.ndarray = np.array([]) 16 | # TODO: add option to use default params like with BM25Encoder 17 | word_index: Dict = {} 18 | 19 | def __init__(self, name: str | None = None): 20 | if name is None: 21 | name = "tfidf" 22 | super().__init__(name=name) 23 | self.word_index = {} 24 | self.idf = np.array([]) 25 | 26 | def __call__(self, docs: List[str]) -> list[SparseEmbedding]: 27 | if len(self.word_index) == 0 or self.idf.size == 0: 28 | raise ValueError("Vectorizer is not initialized.") 29 | if len(docs) == 0: 30 | raise ValueError("No documents to encode.") 31 | 32 | docs = [self._preprocess(doc) for doc in docs] 33 | tf = self._compute_tf(docs) 34 | tfidf = tf * self.idf 35 | return self._array_to_sparse_embeddings(tfidf) 36 | 37 | async def acall(self, docs: List[str]) -> List[SparseEmbedding]: 38 | return await asyncio.to_thread(lambda: self.__call__(docs)) 39 | 40 | def fit(self, routes: List[Route]): 41 | """Trains the encoder weights on the provided routes. 42 | 43 | :param routes: List of routes to train the encoder on. 44 | :type routes: List[Route] 45 | """ 46 | self._fit_validate(routes=routes) 47 | docs = [] 48 | for route in routes: 49 | for doc in route.utterances: 50 | docs.append(self._preprocess(doc)) # type: ignore 51 | self.word_index = self._build_word_index(docs) 52 | if len(self.word_index) == 0: 53 | raise ValueError(f"Too little data to fit {self.__class__.__name__}.") 54 | self.idf = self._compute_idf(docs) 55 | 56 | def _fit_validate(self, routes: List[Route]): 57 | if not isinstance(routes, list) or not isinstance(routes[0], Route): 58 | raise TypeError("`routes` parameter must be a list of Route objects.") 59 | 60 | def _build_word_index(self, docs: List[str]) -> Dict: 61 | words = set() 62 | for doc in docs: 63 | for word in doc.split(): 64 | words.add(word) 65 | word_index = {word: i for i, word in enumerate(words)} 66 | return word_index 67 | 68 | def _compute_tf(self, docs: List[str]) -> np.ndarray: 69 | if len(self.word_index) == 0: 70 | raise ValueError("Word index is not initialized.") 71 | tf = np.zeros((len(docs), len(self.word_index))) 72 | for i, doc in enumerate(docs): 73 | word_counts = Counter(doc.split()) 74 | for word, count in word_counts.items(): 75 | if word in self.word_index: 76 | tf[i, self.word_index[word]] = count 77 | # L2 normalization 78 | tf = tf / np.linalg.norm(tf, axis=1, keepdims=True) 79 | return tf 80 | 81 | def _compute_idf(self, docs: List[str]) -> np.ndarray: 82 | if len(self.word_index) == 0: 83 | raise ValueError("Word index is not initialized.") 84 | idf = np.zeros(len(self.word_index)) 85 | for doc in docs: 86 | words = set(doc.split()) 87 | for word in words: 88 | if word in self.word_index: 89 | idf[self.word_index[word]] += 1 90 | idf = np.log(len(docs) / (idf + 1)) 91 | return idf 92 | 93 | def _preprocess(self, doc: str) -> str: 94 | lowercased_doc = doc.lower() 95 | 
no_punctuation_doc = lowercased_doc.translate( 96 | str.maketrans("", "", string.punctuation) 97 | ) 98 | return no_punctuation_doc 99 | -------------------------------------------------------------------------------- /semantic_router/encoders/vit.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from pydantic import PrivateAttr 4 | 5 | from semantic_router.encoders import DenseEncoder 6 | 7 | 8 | class VitEncoder(DenseEncoder): 9 | """Encoder for Vision Transformer models. 10 | 11 | This class provides functionality to encode images using a Vision Transformer 12 | model via Hugging Face. It supports various image processing and model initialization 13 | options. 14 | """ 15 | 16 | name: str = "google/vit-base-patch16-224" 17 | type: str = "huggingface" 18 | processor_kwargs: Dict = {} 19 | model_kwargs: Dict = {} 20 | device: Optional[str] = None 21 | _processor: Any = PrivateAttr() 22 | _model: Any = PrivateAttr() 23 | _torch: Any = PrivateAttr() 24 | _T: Any = PrivateAttr() 25 | _Image: Any = PrivateAttr() 26 | 27 | def __init__(self, **data): 28 | """Initialize the VitEncoder. 29 | 30 | :param **data: Additional keyword arguments for the encoder. 31 | :type **data: dict 32 | """ 33 | if data.get("score_threshold") is None: 34 | data["score_threshold"] = 0.5 35 | super().__init__(**data) 36 | self._processor, self._model = self._initialize_hf_model() 37 | 38 | def _initialize_hf_model(self): 39 | """Initialize the Hugging Face model. 40 | 41 | :return: The processor and model. 42 | :rtype: tuple 43 | """ 44 | try: 45 | from transformers import ViTImageProcessor, ViTModel 46 | except ImportError: 47 | raise ImportError( 48 | "Please install transformers to use VitEncoder. " 49 | "You can install it with: " 50 | "`pip install semantic-router[vision]`" 51 | ) 52 | 53 | try: 54 | import torch 55 | import torchvision.transforms as T 56 | except ImportError: 57 | raise ImportError( 58 | "Please install Pytorch to use VitEncoder. " 59 | "You can install it with: " 60 | "`pip install semantic-router[vision]`" 61 | ) 62 | 63 | try: 64 | from PIL import Image 65 | except ImportError: 66 | raise ImportError( 67 | "Please install PIL to use VitEncoder. " 68 | "You can install it with: " 69 | "`pip install semantic-router[vision]`" 70 | ) 71 | 72 | self._torch = torch 73 | self._Image = Image 74 | self._T = T 75 | 76 | processor = ViTImageProcessor.from_pretrained( 77 | self.name, **self.processor_kwargs 78 | ) 79 | 80 | model = ViTModel.from_pretrained(self.name, **self.model_kwargs) 81 | 82 | self.device = self._get_device() 83 | model.to(self.device) 84 | 85 | return processor, model 86 | 87 | def _get_device(self) -> str: 88 | """Get the device to use for the model. 89 | 90 | :return: The device to use for the model. 91 | :rtype: str 92 | """ 93 | if self.device: 94 | device = self.device 95 | elif self._torch.cuda.is_available(): 96 | device = "cuda" 97 | elif self._torch.backends.mps.is_available(): 98 | device = "mps" 99 | else: 100 | device = "cpu" 101 | return device 102 | 103 | def _process_images(self, images: List[Any]): 104 | """Process the images for the model. 105 | 106 | :param images: The images to process. 107 | :type images: List[Any] 108 | :return: The processed images. 
109 | :rtype: Any 110 | """ 111 | rgb_images = [self._ensure_rgb(img) for img in images] 112 | processed_images = self._processor(images=rgb_images, return_tensors="pt") 113 | processed_images = processed_images.to(self.device) 114 | return processed_images 115 | 116 | def _ensure_rgb(self, img: Any): 117 | """Ensure the image is in RGB format. 118 | 119 | :param img: The image to ensure is in RGB format. 120 | :type img: Any 121 | :return: The image in RGB format. 122 | :rtype: Any 123 | """ 124 | rgbimg = self._Image.new("RGB", img.size) 125 | rgbimg.paste(img) 126 | return rgbimg 127 | 128 | def __call__( 129 | self, 130 | imgs: List[Any], 131 | batch_size: int = 32, 132 | ) -> List[List[float]]: 133 | """Encode a list of images into embeddings using the Vision Transformer model. 134 | 135 | :param imgs: The images to encode. 136 | :type imgs: List[Any] 137 | :param batch_size: The batch size for encoding. 138 | :type batch_size: int 139 | :return: The embeddings for the images. 140 | :rtype: List[List[float]] 141 | """ 142 | all_embeddings = [] 143 | for i in range(0, len(imgs), batch_size): 144 | batch_imgs = imgs[i : i + batch_size] 145 | batch_imgs_transform = self._process_images(batch_imgs) 146 | with self._torch.no_grad(): 147 | embeddings = ( 148 | self._model(**batch_imgs_transform) 149 | .last_hidden_state[:, 0] 150 | .cpu() 151 | .tolist() 152 | ) 153 | all_embeddings.extend(embeddings) 154 | return all_embeddings 155 | -------------------------------------------------------------------------------- /semantic_router/encoders/voyage.py: -------------------------------------------------------------------------------- 1 | """This file contains the VoyageEncoder class which is used to encode text using Voyage""" 2 | 3 | import litellm 4 | 5 | from semantic_router.encoders.litellm import LiteLLMEncoder, litellm_to_list 6 | from semantic_router.utils.defaults import EncoderDefault 7 | 8 | 9 | class VoyageEncoder(LiteLLMEncoder): 10 | """Class to encode text using Voyage. Requires a Voyage API key from 11 | https://voyageai.com/api-keys/""" 12 | 13 | type: str = "voyage" 14 | 15 | def __init__( 16 | self, 17 | name: str | None = None, 18 | api_key: str | None = None, 19 | score_threshold: float = 0.4, 20 | ): 21 | """Initialize the VoyageEncoder. 22 | 23 | :param name: The name of the embedding model to use such as "voyage-embed". 24 | :type name: str 25 | :param api_key: The Voyage API key. 26 | :type api_key: str | None 27 | """ 28 | 29 | if name is None: 30 | name = f"voyage/{EncoderDefault.VOYAGE.value['embedding_model']}" 31 | elif not name.startswith("voyage/"): 32 | name = f"voyage/{name}" 33 | super().__init__( 34 | name=name, 35 | score_threshold=score_threshold, 36 | api_key=api_key, 37 | ) 38 | 39 | def encode_queries(self, docs: list[str], **kwargs) -> list[list[float]]: 40 | try: 41 | embeds = litellm.embedding( 42 | input=docs, 43 | input_type="query", 44 | model=f"{self.type}/{self.name}", 45 | **kwargs, 46 | ) 47 | return litellm_to_list(embeds) 48 | except Exception as e: 49 | raise ValueError(f"Voyage API call failed. Error: {e}") from e 50 | 51 | def encode_documents(self, docs: list[str], **kwargs) -> list[list[float]]: 52 | try: 53 | embeds = litellm.embedding( 54 | input=docs, 55 | input_type="document", 56 | model=f"{self.type}/{self.name}", 57 | **kwargs, 58 | ) 59 | return litellm_to_list(embeds) 60 | except Exception as e: 61 | raise ValueError(f"Voyage API call failed. 
Error: {e}") from e 62 | 63 | async def aencode_queries(self, docs: list[str], **kwargs) -> list[list[float]]: 64 | try: 65 | embeds = await litellm.aembedding( 66 | input=docs, 67 | input_type="query", 68 | model=f"{self.type}/{self.name}", 69 | **kwargs, 70 | ) 71 | return litellm_to_list(embeds) 72 | except Exception as e: 73 | raise ValueError(f"Voyage API call failed. Error: {e}") from e 74 | 75 | async def aencode_documents(self, docs: list[str], **kwargs) -> list[list[float]]: 76 | try: 77 | embeds = await litellm.aembedding( 78 | input=docs, 79 | input_type="document", 80 | model=f"{self.type}/{self.name}", 81 | **kwargs, 82 | ) 83 | return litellm_to_list(embeds) 84 | except Exception as e: 85 | raise ValueError(f"Voyage API call failed. Error: {e}") from e 86 | -------------------------------------------------------------------------------- /semantic_router/index/__init__.py: -------------------------------------------------------------------------------- 1 | from semantic_router.index.base import BaseIndex 2 | from semantic_router.index.hybrid_local import HybridLocalIndex 3 | from semantic_router.index.local import LocalIndex 4 | from semantic_router.index.pinecone import PineconeIndex 5 | from semantic_router.index.postgres import PostgresIndex 6 | from semantic_router.index.qdrant import QdrantIndex 7 | 8 | __all__ = [ 9 | "BaseIndex", 10 | "HybridLocalIndex", 11 | "LocalIndex", 12 | "QdrantIndex", 13 | "PineconeIndex", 14 | "PostgresIndex", 15 | ] 16 | -------------------------------------------------------------------------------- /semantic_router/linear.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | from numpy.linalg import norm 5 | 6 | 7 | def similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray: 8 | """Compute the similarity scores between a query vector and a set of vectors. 9 | 10 | :param xq: A query vector (1d ndarray) 11 | :param index: A set of vectors. 12 | :return: The similarity between the query vector and the set of vectors. 13 | :rtype: np.ndarray 14 | """ 15 | 16 | index_norm = norm(index, axis=1) 17 | xq_norm = norm(xq.T) 18 | sim = np.dot(index, xq.T) / (index_norm * xq_norm) 19 | return sim 20 | 21 | 22 | def top_scores(sim: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]: 23 | """Get the top scores and indices from a similarity matrix. 24 | 25 | :param sim: A similarity matrix. 26 | :param top_k: The number of top scores to get. 27 | :return: The top scores and indices. 
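As an illustrative example (hypothetical values): `top_scores(np.array([0.1, 0.9, 0.5]), top_k=2)` would return something like `(array([0.5, 0.9]), array([2, 1]))`; note that `np.argpartition` only guarantees which indices end up in the top-k, not their ordering within it.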
28 | :rtype: Tuple[np.ndarray, np.ndarray] 29 | """ 30 | top_k = min(top_k, sim.shape[0]) 31 | idx = np.argpartition(sim, -top_k)[-top_k:] 32 | scores = sim[idx] 33 | 34 | return scores, idx 35 | -------------------------------------------------------------------------------- /semantic_router/llms/__init__.py: -------------------------------------------------------------------------------- 1 | from semantic_router.llms.base import BaseLLM 2 | from semantic_router.llms.cohere import CohereLLM 3 | from semantic_router.llms.llamacpp import LlamaCppLLM 4 | from semantic_router.llms.mistral import MistralAILLM 5 | from semantic_router.llms.openai import OpenAILLM 6 | from semantic_router.llms.openrouter import OpenRouterLLM 7 | from semantic_router.llms.zure import AzureOpenAILLM 8 | 9 | __all__ = [ 10 | "BaseLLM", 11 | "OpenAILLM", 12 | "LlamaCppLLM", 13 | "OpenRouterLLM", 14 | "CohereLLM", 15 | "AzureOpenAILLM", 16 | "MistralAILLM", 17 | ] 18 | -------------------------------------------------------------------------------- /semantic_router/llms/cohere.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, List, Optional 3 | 4 | from pydantic import PrivateAttr 5 | 6 | from semantic_router.llms import BaseLLM 7 | from semantic_router.schema import Message 8 | 9 | 10 | class CohereLLM(BaseLLM): 11 | """LLM for Cohere. Requires a Cohere API key from https://dashboard.cohere.com/api-keys. 12 | 13 | This class provides functionality to interact with the Cohere API for generating text responses. 14 | It extends the BaseLLM class and implements the __call__ method to generate text responses. 15 | """ 16 | 17 | _client: Any = PrivateAttr() 18 | 19 | def __init__( 20 | self, 21 | name: Optional[str] = None, 22 | cohere_api_key: Optional[str] = None, 23 | ): 24 | """Initialize the CohereLLM. 25 | 26 | :param name: The name of the Cohere model to use can also be set via the 27 | COHERE_CHAT_MODEL_NAME environment variable. 28 | :type name: Optional[str] 29 | :param cohere_api_key: The API key for the Cohere client. Can also be set via the 30 | COHERE_API_KEY environment variable. 31 | :type cohere_api_key: Optional[str] 32 | """ 33 | if name is None: 34 | name = os.getenv("COHERE_CHAT_MODEL_NAME", "command") 35 | super().__init__(name=name) 36 | self._client = self._initialize_client(cohere_api_key) 37 | 38 | def _initialize_client(self, cohere_api_key: Optional[str] = None): 39 | """Initialize the Cohere client. 40 | 41 | :param cohere_api_key: The API key for the Cohere client. Can also be set via the 42 | COHERE_API_KEY environment variable. 43 | :type cohere_api_key: Optional[str] 44 | """ 45 | try: 46 | import cohere 47 | except ImportError: 48 | raise ImportError( 49 | "Please install Cohere to use CohereLLM. " 50 | "You can install it with: " 51 | "`pip install 'semantic-router[cohere]'`" 52 | ) 53 | cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") 54 | if cohere_api_key is None: 55 | raise ValueError("Cohere API key cannot be 'None'.") 56 | try: 57 | client = cohere.Client(cohere_api_key) 58 | except Exception as e: 59 | raise ValueError( 60 | f"Cohere API client failed to initialize. Error: {e}" 61 | ) from e 62 | return client 63 | 64 | def __call__(self, messages: List[Message]) -> str: 65 | """Call the Cohere client. 66 | 67 | :param messages: The messages to pass to the Cohere client. 68 | :type messages: List[Message] 69 | :return: The response from the Cohere client. 
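A minimal, hypothetical usage sketch, assuming a valid `COHERE_API_KEY` is available: `CohereLLM()([Message(role="user", content="Hello")])` returns the generated reply text.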
70 | :rtype: str 71 | """ 72 | if self._client is None: 73 | raise ValueError("Cohere client is not initialized.") 74 | try: 75 | completion = self._client.chat( 76 | model=self.name, 77 | chat_history=[m.to_cohere() for m in messages[:-1]], 78 | message=messages[-1].content, 79 | ) 80 | 81 | output = completion.text 82 | 83 | if not output: 84 | raise Exception("No output generated") 85 | return output 86 | 87 | except Exception as e: 88 | raise ValueError(f"Cohere API call failed. Error: {e}") from e 89 | -------------------------------------------------------------------------------- /semantic_router/llms/grammars/json.gbnf: -------------------------------------------------------------------------------- 1 | root ::= object 2 | value ::= object | array | string | number | ("true" | "false" | "null") ws 3 | 4 | object ::= 5 | "{" ws ( 6 | string ":" ws value 7 | ("," ws string ":" ws value)* 8 | )? "}" ws 9 | 10 | array ::= 11 | "[" ws ( 12 | value 13 | ("," ws value)* 14 | )? "]" ws 15 | 16 | string ::= 17 | "\"" ( 18 | [^"\\] | 19 | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes 20 | )* "\"" ws 21 | 22 | number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws 23 | 24 | # Optional space: by convention, applied in this grammar after literal chars when allowed 25 | ws ::= ([ \t\n] ws)? 26 | -------------------------------------------------------------------------------- /semantic_router/llms/llamacpp.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from pathlib import Path 3 | from typing import Any, Dict, List, Optional 4 | 5 | from pydantic import PrivateAttr 6 | 7 | from semantic_router.llms.base import BaseLLM 8 | from semantic_router.schema import Message 9 | from semantic_router.utils.logger import logger 10 | 11 | 12 | class LlamaCppLLM(BaseLLM): 13 | """LLM for LlamaCPP. Enables fully local LLM use, helpful for local implementation of 14 | dynamic routes. 15 | """ 16 | 17 | llm: Any 18 | grammar: Optional[Any] = None 19 | _llama_cpp: Any = PrivateAttr() 20 | 21 | def __init__( 22 | self, 23 | llm: Any, 24 | name: str = "llama.cpp", 25 | temperature: float = 0.2, 26 | max_tokens: Optional[int] = 200, 27 | grammar: Optional[Any] = None, 28 | ): 29 | """Initialize the LlamaCPPLLM. 30 | 31 | :param llm: The LLM to use. 32 | :type llm: Any 33 | :param name: The name of the LLM. 34 | :type name: str 35 | :param temperature: The temperature of the LLM. 36 | :type temperature: float 37 | :param max_tokens: The maximum number of tokens to generate. 38 | :type max_tokens: Optional[int] 39 | :param grammar: The grammar to use. 40 | :type grammar: Optional[Any] 41 | """ 42 | super().__init__( 43 | name=name, 44 | llm=llm, 45 | temperature=temperature, 46 | max_tokens=max_tokens, 47 | grammar=grammar, 48 | ) 49 | 50 | try: 51 | import llama_cpp 52 | except ImportError: 53 | raise ImportError( 54 | "Please install LlamaCPP to use Llama CPP llm. " 55 | "You can install it with: " 56 | "`pip install 'semantic-router[local]'`" 57 | ) 58 | self._llama_cpp = llama_cpp 59 | self.llm = llm 60 | self.temperature = temperature 61 | self.max_tokens = max_tokens 62 | self.grammar = grammar 63 | 64 | def __call__( 65 | self, 66 | messages: List[Message], 67 | ) -> str: 68 | """Call the LlamaCPPLLM. 69 | 70 | :param messages: The messages to pass to the LlamaCPPLLM. 71 | :type messages: List[Message] 72 | :return: The response from the LlamaCPPLLM. 
73 | :rtype: str 74 | """ 75 | try: 76 | completion = self.llm.create_chat_completion( 77 | messages=[m.to_llamacpp() for m in messages], 78 | temperature=self.temperature, 79 | max_tokens=self.max_tokens, 80 | grammar=self.grammar, 81 | stream=False, 82 | ) 83 | assert isinstance(completion, dict) # keep mypy happy 84 | output = completion["choices"][0]["message"]["content"] 85 | 86 | if not output: 87 | raise Exception("No output generated") 88 | return output 89 | except Exception as e: 90 | logger.error(f"LLM error: {e}") 91 | raise 92 | 93 | @contextmanager 94 | def _grammar(self): 95 | """Context manager for the grammar. 96 | 97 | :return: The grammar. 98 | :rtype: Any 99 | """ 100 | grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf") 101 | assert grammar_path.exists(), f"{grammar_path}\ndoes not exist" 102 | try: 103 | self.grammar = self._llama_cpp.LlamaGrammar.from_file(grammar_path) 104 | yield 105 | finally: 106 | self.grammar = None 107 | 108 | def extract_function_inputs( 109 | self, query: str, function_schemas: List[Dict[str, Any]] 110 | ) -> List[Dict[str, Any]]: 111 | """Extract the function inputs from the query. 112 | 113 | :param query: The query to extract the function inputs from. 114 | :type query: str 115 | :param function_schemas: The function schemas to extract the function inputs from. 116 | :type function_schemas: List[Dict[str, Any]] 117 | :return: The function inputs. 118 | :rtype: List[Dict[str, Any]] 119 | """ 120 | with self._grammar(): 121 | return super().extract_function_inputs( 122 | query=query, function_schemas=function_schemas 123 | ) 124 | -------------------------------------------------------------------------------- /semantic_router/llms/mistral.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, List, Optional 3 | 4 | from pydantic import PrivateAttr 5 | 6 | from semantic_router.llms import BaseLLM 7 | from semantic_router.schema import Message 8 | from semantic_router.utils.defaults import EncoderDefault 9 | from semantic_router.utils.logger import logger 10 | 11 | 12 | class MistralAILLM(BaseLLM): 13 | """LLM for MistralAI. Requires a MistralAI API key from https://console.mistral.ai/api-keys/""" 14 | 15 | _client: Any = PrivateAttr() 16 | _mistralai: Any = PrivateAttr() 17 | 18 | def __init__( 19 | self, 20 | name: Optional[str] = None, 21 | mistralai_api_key: Optional[str] = None, 22 | temperature: float = 0.01, 23 | max_tokens: int = 200, 24 | ): 25 | """Initialize the MistralAILLM. 26 | 27 | :param name: The name of the MistralAI model to use. 28 | :type name: Optional[str] 29 | :param mistralai_api_key: The MistralAI API key. 30 | :type mistralai_api_key: Optional[str] 31 | :param temperature: The temperature of the LLM. 32 | :type temperature: float 33 | :param max_tokens: The maximum number of tokens to generate. 34 | :type max_tokens: int 35 | """ 36 | if name is None: 37 | name = EncoderDefault.MISTRAL.value["language_model"] 38 | super().__init__(name=name) 39 | self._client, self._mistralai = self._initialize_client(mistralai_api_key) 40 | self.temperature = temperature 41 | self.max_tokens = max_tokens 42 | 43 | def _initialize_client(self, api_key): 44 | """Initialize the MistralAI client. 45 | 46 | :param api_key: The MistralAI API key. 47 | :type api_key: Optional[str] 48 | :return: The MistralAI client. 
49 | :rtype: MistralClient 50 | """ 51 | try: 52 | import mistralai 53 | from mistralai.client import MistralClient 54 | except ImportError: 55 | raise ImportError( 56 | "Please install MistralAI to use MistralAI LLM. " 57 | "You can install it with: " 58 | "`pip install 'semantic-router[mistralai]'`" 59 | ) 60 | api_key = api_key or os.getenv("MISTRALAI_API_KEY") 61 | if api_key is None: 62 | raise ValueError("MistralAI API key cannot be 'None'.") 63 | try: 64 | client = MistralClient(api_key=api_key) 65 | except Exception as e: 66 | raise ValueError( 67 | f"MistralAI API client failed to initialize. Error: {e}" 68 | ) from e 69 | return client, mistralai 70 | 71 | def __call__(self, messages: List[Message]) -> str: 72 | """Call the MistralAILLM. 73 | 74 | :param messages: The messages to pass to the MistralAILLM. 75 | :type messages: List[Message] 76 | :return: The response from the MistralAILLM. 77 | :rtype: str 78 | """ 79 | if self._client is None: 80 | raise ValueError("MistralAI client is not initialized.") 81 | chat_messages = [ 82 | self._mistralai.models.chat_completion.ChatMessage( 83 | role=m.role, content=m.content 84 | ) 85 | for m in messages 86 | ] 87 | try: 88 | completion = self._client.chat( 89 | model=self.name, 90 | messages=chat_messages, 91 | temperature=self.temperature, 92 | max_tokens=self.max_tokens, 93 | ) 94 | 95 | output = completion.choices[0].message.content 96 | 97 | if not output: 98 | raise Exception("No output generated") 99 | return output 100 | except Exception as e: 101 | logger.error(f"LLM error: {e}") 102 | raise Exception(f"LLM error: {e}") from e 103 | -------------------------------------------------------------------------------- /semantic_router/llms/ollama.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import requests 4 | 5 | from semantic_router.llms import BaseLLM 6 | from semantic_router.schema import Message 7 | from semantic_router.utils.logger import logger 8 | 9 | 10 | class OllamaLLM(BaseLLM): 11 | """LLM for Ollama. Enables fully local LLM use, helpful for local implementation of 12 | dynamic routes. 13 | """ 14 | 15 | stream: bool = False 16 | 17 | def __init__( 18 | self, 19 | name: str = "openhermes", 20 | temperature: float = 0.2, 21 | max_tokens: Optional[int] = 200, 22 | stream: bool = False, 23 | ): 24 | """Initialize the OllamaLLM. 25 | 26 | :param name: The name of the Ollama model to use. 27 | :type name: str 28 | :param temperature: The temperature of the LLM. 29 | :type temperature: float 30 | :param max_tokens: The maximum number of tokens to generate. 31 | :type max_tokens: Optional[int] 32 | :param stream: Whether to stream the response. 33 | :type stream: bool 34 | """ 35 | super().__init__(name=name) 36 | self.temperature = temperature 37 | self.max_tokens = max_tokens 38 | self.stream = stream 39 | 40 | def __call__( 41 | self, 42 | messages: List[Message], 43 | temperature: Optional[float] = None, 44 | name: Optional[str] = None, 45 | max_tokens: Optional[int] = None, 46 | stream: Optional[bool] = None, 47 | ) -> str: 48 | """Call the OllamaLLM. 49 | 50 | :param messages: The messages to pass to the OllamaLLM. 51 | :type messages: List[Message] 52 | :param temperature: The temperature of the LLM. 53 | :type temperature: Optional[float] 54 | :param name: The name of the Ollama model to use. 55 | :type name: Optional[str] 56 | :param max_tokens: The maximum number of tokens to generate. 
57 | :type max_tokens: Optional[int] 58 | :param stream: Whether to stream the response. 59 | :type stream: Optional[bool] 60 | """ 61 | # Use instance defaults if not overridden 62 | temperature = temperature if temperature is not None else self.temperature 63 | name = name if name is not None else self.name 64 | max_tokens = max_tokens if max_tokens is not None else self.max_tokens 65 | stream = stream if stream is not None else self.stream 66 | 67 | try: 68 | payload = { 69 | "model": name, 70 | "messages": [m.to_openai() for m in messages], 71 | "options": {"temperature": temperature, "num_predict": max_tokens}, 72 | "format": "json", 73 | "stream": stream, 74 | } 75 | response = requests.post("http://localhost:11434/api/chat", json=payload) 76 | output = response.json()["message"]["content"] 77 | 78 | return output 79 | except Exception as e: 80 | logger.error(f"LLM error: {e}") 81 | raise Exception(f"LLM error: {e}") from e 82 | -------------------------------------------------------------------------------- /semantic_router/llms/openrouter.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | import openai 5 | from pydantic import PrivateAttr 6 | 7 | from semantic_router.llms import BaseLLM 8 | from semantic_router.schema import Message 9 | from semantic_router.utils.logger import logger 10 | 11 | 12 | class OpenRouterLLM(BaseLLM): 13 | """LLM for OpenRouter. Requires an OpenRouter API key, see here for more information 14 | https://openrouter.ai/docs/api-reference/authentication#using-an-api-key""" 15 | 16 | _client: Optional[openai.OpenAI] = PrivateAttr(default=None) 17 | _base_url: str = PrivateAttr(default="https://openrouter.ai/api/v1") 18 | 19 | def __init__( 20 | self, 21 | name: Optional[str] = None, 22 | openrouter_api_key: Optional[str] = None, 23 | base_url: str = "https://openrouter.ai/api/v1", 24 | temperature: float = 0.01, 25 | max_tokens: int = 200, 26 | ): 27 | """Initialize the OpenRouterLLM. 28 | 29 | :param name: The name of the OpenRouter model to use. 30 | :type name: Optional[str] 31 | :param openrouter_api_key: The OpenRouter API key. 32 | :type openrouter_api_key: Optional[str] 33 | :param base_url: The base URL for the OpenRouter API. 34 | :type base_url: str 35 | :param temperature: The temperature of the LLM. 36 | :type temperature: float 37 | :param max_tokens: The maximum number of tokens to generate. 38 | :type max_tokens: int 39 | """ 40 | if name is None: 41 | name = os.getenv( 42 | "OPENROUTER_CHAT_MODEL_NAME", "mistralai/mistral-7b-instruct" 43 | ) 44 | super().__init__(name=name) 45 | self._base_url = base_url 46 | api_key = openrouter_api_key or os.getenv("OPENROUTER_API_KEY") 47 | if api_key is None: 48 | raise ValueError("OpenRouter API key cannot be 'None'.") 49 | try: 50 | self._client = openai.OpenAI(api_key=api_key, base_url=self._base_url) 51 | except Exception as e: 52 | raise ValueError( 53 | f"OpenRouter API client failed to initialize. Error: {e}" 54 | ) from e 55 | self.temperature = temperature 56 | self.max_tokens = max_tokens 57 | 58 | def __call__(self, messages: List[Message]) -> str: 59 | """Call the OpenRouterLLM. 60 | 61 | :param messages: The messages to pass to the OpenRouterLLM. 62 | :type messages: List[Message] 63 | :return: The response from the OpenRouterLLM. 
64 | :rtype: str 65 | """ 66 | if self._client is None: 67 | raise ValueError("OpenRouter client is not initialized.") 68 | try: 69 | completion = self._client.chat.completions.create( 70 | model=self.name, 71 | messages=[m.to_openai() for m in messages], 72 | temperature=self.temperature, 73 | max_tokens=self.max_tokens, 74 | ) 75 | 76 | output = completion.choices[0].message.content 77 | 78 | if not output: 79 | raise Exception("No output generated") 80 | return output 81 | except Exception as e: 82 | logger.error(f"LLM error: {e}") 83 | raise Exception(f"LLM error: {e}") from e 84 | -------------------------------------------------------------------------------- /semantic_router/llms/zure.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | import openai 5 | from pydantic import PrivateAttr 6 | 7 | from semantic_router.llms import BaseLLM 8 | from semantic_router.schema import Message 9 | from semantic_router.utils.defaults import EncoderDefault 10 | from semantic_router.utils.logger import logger 11 | 12 | 13 | class AzureOpenAILLM(BaseLLM): 14 | """LLM for Azure OpenAI. Requires an Azure OpenAI API key.""" 15 | 16 | _client: Optional[openai.AzureOpenAI] = PrivateAttr(default=None) 17 | 18 | def __init__( 19 | self, 20 | name: Optional[str] = None, 21 | openai_api_key: Optional[str] = None, 22 | azure_endpoint: Optional[str] = None, 23 | temperature: float = 0.01, 24 | max_tokens: int = 200, 25 | api_version="2023-07-01-preview", 26 | ): 27 | """Initialize the AzureOpenAILLM. 28 | 29 | :param name: The name of the Azure OpenAI model to use. 30 | :type name: Optional[str] 31 | :param openai_api_key: The Azure OpenAI API key. 32 | :type openai_api_key: Optional[str] 33 | :param azure_endpoint: The Azure OpenAI endpoint. 34 | :type azure_endpoint: Optional[str] 35 | :param temperature: The temperature of the LLM. 36 | :type temperature: float 37 | :param max_tokens: The maximum number of tokens to generate. 38 | :type max_tokens: int 39 | :param api_version: The API version to use. 40 | :type api_version: str 41 | """ 42 | if name is None: 43 | name = EncoderDefault.AZURE.value["language_model"] 44 | super().__init__(name=name) 45 | api_key = openai_api_key or os.getenv("AZURE_OPENAI_API_KEY") 46 | if api_key is None: 47 | raise ValueError("AzureOpenAI API key cannot be 'None'.") 48 | azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT") 49 | if azure_endpoint is None: 50 | raise ValueError("Azure endpoint API key cannot be 'None'.") 51 | try: 52 | self._client = openai.AzureOpenAI( 53 | api_key=api_key, azure_endpoint=azure_endpoint, api_version=api_version 54 | ) 55 | except Exception as e: 56 | raise ValueError(f"AzureOpenAI API client failed to initialize. Error: {e}") 57 | self.temperature = temperature 58 | self.max_tokens = max_tokens 59 | 60 | def __call__(self, messages: List[Message]) -> str: 61 | """Call the AzureOpenAILLM. 62 | 63 | :param messages: The messages to pass to the AzureOpenAILLM. 64 | :type messages: List[Message] 65 | :return: The response from the AzureOpenAILLM. 
66 | :rtype: str 67 | """ 68 | if self._client is None: 69 | raise ValueError("AzureOpenAI client is not initialized.") 70 | try: 71 | completion = self._client.chat.completions.create( 72 | model=self.name, 73 | messages=[m.to_openai() for m in messages], 74 | temperature=self.temperature, 75 | max_tokens=self.max_tokens, 76 | ) 77 | 78 | output = completion.choices[0].message.content 79 | 80 | if not output: 81 | raise Exception("No output generated") 82 | return output 83 | except Exception as e: 84 | logger.error(f"LLM error: {e}") 85 | raise Exception(f"LLM error: {e}") from e 86 | -------------------------------------------------------------------------------- /semantic_router/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aurelio-labs/semantic-router/5695fcaf1d87043cf7456de9ebc2a9776cf075e8/semantic_router/py.typed -------------------------------------------------------------------------------- /semantic_router/routers/__init__.py: -------------------------------------------------------------------------------- 1 | from semantic_router.routers.base import BaseRouter, RouterConfig 2 | from semantic_router.routers.hybrid import HybridRouter 3 | from semantic_router.routers.semantic import SemanticRouter 4 | 5 | __all__ = [ 6 | "BaseRouter", 7 | "RouterConfig", 8 | "SemanticRouter", 9 | "HybridRouter", 10 | ] 11 | -------------------------------------------------------------------------------- /semantic_router/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aurelio-labs/semantic-router/5695fcaf1d87043cf7456de9ebc2a9776cf075e8/semantic_router/utils/__init__.py -------------------------------------------------------------------------------- /semantic_router/utils/defaults.py: -------------------------------------------------------------------------------- 1 | import os 2 | from enum import Enum 3 | 4 | 5 | class EncoderDefault(Enum): 6 | """Default model names for each encoder type.""" 7 | 8 | FASTEMBED = { 9 | "embedding_model": "BAAI/bge-small-en-v1.5", 10 | "language_model": "BAAI/bge-small-en-v1.5", 11 | } 12 | OPENAI = { 13 | "embedding_model": os.getenv("OPENAI_MODEL_NAME", "text-embedding-3-small"), 14 | "language_model": os.getenv("OPENAI_CHAT_MODEL_NAME", "gpt-4o"), 15 | } 16 | COHERE = { 17 | "embedding_model": os.getenv("COHERE_MODEL_NAME", "embed-english-v3.0"), 18 | "language_model": os.getenv("COHERE_CHAT_MODEL_NAME", "command"), 19 | } 20 | MISTRAL = { 21 | "embedding_model": os.getenv("MISTRAL_MODEL_NAME", "mistral-embed"), 22 | "language_model": os.getenv("MISTRALAI_CHAT_MODEL_NAME", "mistral-tiny"), 23 | } 24 | VOYAGE = { 25 | "embedding_model": os.getenv("VOYAGE_MODEL_NAME", "voyage-3-lite"), 26 | "language_model": os.getenv("VOYAGE_CHAT_MODEL_NAME", "voyage-3-lite"), 27 | } 28 | JINA = { 29 | "embedding_model": os.getenv("JINA_MODEL_NAME", "jina-embeddings-v3"), 30 | "language_model": os.getenv("JINA_CHAT_MODEL_NAME", "ReaderLM-v2"), 31 | } 32 | NVIDIA_NIM = { 33 | "embedding_model": os.getenv( 34 | "NVIDIA_NIM_MODEL_NAME", "nvidia/nv-embedqa-e5-v5" 35 | ), 36 | "language_model": os.getenv( 37 | "NVIDIA_NIM_CHAT_MODEL_NAME", "meta/llama3-70b-instruct" 38 | ), 39 | } 40 | AZURE = { 41 | "embedding_model": os.getenv("AZURE_OPENAI_MODEL", "text-embedding-3-small"), 42 | "language_model": os.getenv("OPENAI_CHAT_MODEL_NAME", "gpt-4o"), 43 | "deployment_name": os.getenv( 44 | "AZURE_OPENAI_DEPLOYMENT_NAME", 
"text-embedding-3-small" 45 | ), 46 | } 47 | GOOGLE = { 48 | "embedding_model": os.getenv( 49 | "GOOGLE_EMBEDDING_MODEL", "textembedding-gecko@003" 50 | ), 51 | } 52 | OLLAMA = { 53 | "embedding_model": os.getenv( 54 | "OLLAMA_EMBEDDING_MODEL", "hf.co/Qwen/Qwen3-Embedding-0.6B-GGUF:F16" 55 | ) 56 | } 57 | BEDROCK = { 58 | "embedding_model": os.environ.get( 59 | "BEDROCK_EMBEDDING_MODEL", "amazon.titan-embed-image-v1" 60 | ) 61 | } 62 | -------------------------------------------------------------------------------- /semantic_router/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import colorlog 5 | 6 | 7 | class CustomFormatter(colorlog.ColoredFormatter): 8 | """Custom formatter for the logger.""" 9 | 10 | def __init__(self): 11 | super().__init__( 12 | "%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s", 13 | datefmt="%Y-%m-%d %H:%M:%S", 14 | log_colors={ 15 | "DEBUG": "cyan", 16 | "INFO": "green", 17 | "WARNING": "yellow", 18 | "ERROR": "red", 19 | "CRITICAL": "bold_red", 20 | }, 21 | reset=True, 22 | style="%", 23 | ) 24 | 25 | 26 | def add_coloured_handler(logger): 27 | """Add a coloured handler to the logger.""" 28 | formatter = CustomFormatter() 29 | console_handler = logging.StreamHandler() 30 | console_handler.setFormatter(formatter) 31 | logger.addHandler(console_handler) 32 | return logger 33 | 34 | 35 | def setup_custom_logger(name): 36 | """Setup a custom logger.""" 37 | logger = logging.getLogger(name) 38 | 39 | # get log level from environment vars 40 | # first check for semantic_router_log_level, then log_level, then default to INFO 41 | log_level = ( 42 | os.getenv("SEMANTIC_ROUTER_LOG_LEVEL") or os.getenv("LOG_LEVEL") or "INFO" 43 | ) 44 | log_level = log_level.upper() 45 | 46 | add_coloured_handler(logger) 47 | logger.setLevel(log_level) 48 | logger.propagate = False 49 | 50 | return logger 51 | 52 | 53 | logger: logging.Logger = setup_custom_logger("semantic_router") 54 | -------------------------------------------------------------------------------- /tests/functional/encoders/test_bm25_functional.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from semantic_router import Route 7 | from semantic_router.encoders.bm25 import BM25Encoder 8 | 9 | UTTERANCES = [ 10 | "A high weight in tf–idf is reached by a high term frequency", 11 | "(in the given document) and a low document frequency of the term", 12 | "in the whole collection of documents; the weights hence tend to filter", 13 | "out common terms. Since the ratio inside the idf's log function is always", 14 | "greater than or equal to 1, the value of idf (and tf–idf) is greater than or equal", 15 | "to 0. As a term appears in more documents, the ratio inside the logarithm approaches", 16 | "1, bringing the idf and tf–idf closer to 0.", 17 | ] 18 | QUERIES = ["weights", "ratio logarithm"] 19 | 20 | 21 | @pytest.fixture 22 | @pytest.mark.skipif( 23 | os.environ.get("RUN_HF_TESTS") is None, 24 | reason="Set RUN_HF_TESTS=1 to run. 
This test downloads models from Hugging Face which can time out in CI.", 25 | ) 26 | def bm25_encoder(): 27 | sparse_encoder = BM25Encoder(use_default_params=True) 28 | sparse_encoder.fit([Route(name="test_route", utterances=UTTERANCES)]) 29 | return sparse_encoder 30 | 31 | 32 | class TestBM25Encoder: 33 | def _sparse_to_vector(self, sparse_embedding, vocab_size): 34 | """Re-constructs the full (sparse_embedding.shape[0], vocab_size) array""" 35 | return ( 36 | np.eye(vocab_size)[sparse_embedding[:, 0].astype(np.uint).tolist()] 37 | * np.atleast_2d(sparse_embedding[:, 1]).T 38 | ).sum(axis=0) 39 | 40 | @pytest.mark.skipif( 41 | os.environ.get("RUN_HF_TESTS") is None, 42 | reason="Set RUN_HF_TESTS=1 to run. This test downloads models from Hugging Face which can time out in CI.", 43 | ) 44 | def test_bm25_scoring(self, bm25_encoder): 45 | vocab_size = bm25_encoder._tokenizer.vocab_size 46 | expected = np.array( 47 | [ 48 | [0.00000, 0.00000, 0.54575, 0.00000, 0.00000, 0.00000, 0.00000], 49 | [0.00000, 0.00000, 0.00000, 0.18864, 0.00000, 0.67897, 0.00000], 50 | ] 51 | ) 52 | q_e = np.stack( 53 | [ 54 | self._sparse_to_vector(v.embedding, vocab_size=vocab_size) 55 | for v in bm25_encoder.encode_queries(QUERIES) 56 | ] 57 | ) 58 | d_e = np.stack( 59 | [ 60 | self._sparse_to_vector(v.embedding, vocab_size=vocab_size) 61 | for v in bm25_encoder.encode_documents(UTTERANCES) 62 | ] 63 | ) 64 | scores = q_e @ d_e.T 65 | assert np.allclose(scores, expected, rtol=1e-4), expected 66 | -------------------------------------------------------------------------------- /tests/functional/test_linear.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from semantic_router.linear import similarity_matrix, top_scores 5 | 6 | 7 | @pytest.fixture 8 | def ident_vector(): 9 | return np.identity(10)[0] 10 | 11 | 12 | @pytest.fixture 13 | def test_index(): 14 | return np.array([[3, 0, 0], [2, 1, 0], [0, 1, 0]]) 15 | 16 | 17 | def test_similarity_matrix__dimensionality(): 18 | """Test that the similarity matrix is square.""" 19 | xq = np.random.random((10,)) # 10-dimensional embedding vector 20 | index = np.random.random((100, 10)) 21 | S = similarity_matrix(xq, index) 22 | assert S.shape == (100,) 23 | 24 | 25 | def test_similarity_matrix__is_norm_max(ident_vector): 26 | """ 27 | Using identical vectors should yield a maximum similarity of 1 28 | """ 29 | index = np.repeat(np.atleast_2d(ident_vector), 3, axis=0) 30 | sim = similarity_matrix(ident_vector, index) 31 | assert sim.max() == 1.0 32 | 33 | 34 | def test_similarity_matrix__is_norm_min(ident_vector): 35 | """ 36 | Using orthogonal vectors should yield a minimum similarity of 0 37 | """ 38 | orth_v = np.roll(np.atleast_2d(ident_vector), 1) 39 | index = np.repeat(orth_v, 3, axis=0) 40 | sim = similarity_matrix(ident_vector, index) 41 | assert sim.min() == 0.0 42 | 43 | 44 | def test_top_scores__is_sorted(test_index): 45 | """ 46 | Test that the top_scores function returns a sorted list of scores. 47 | """ 48 | 49 | xq = test_index[0] # should have max similarity 50 | 51 | sim = similarity_matrix(xq, test_index) 52 | _, idx = top_scores(sim, 3) 53 | 54 | # Scores and indexes should be sorted ascending 55 | assert np.array_equal(idx, np.array([2, 1, 0])) 56 | 57 | 58 | def test_top_scores__scores(test_index): 59 | """ 60 | Test that for a known vector and a known index, the top_scores function 61 | returns exactly the expected scores. 
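(For reference, with `xq = [3, 0, 0]` the cosine similarities are 1.0 with itself, 6 / (3 * sqrt(5)) ≈ 0.89443 with `[2, 1, 0]`, and 0.0 with `[0, 1, 0]`, which is where the expected values asserted below come from.)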
62 | """ 63 | xq = test_index[0] # should have max similarity 64 | 65 | sim = similarity_matrix(xq, test_index) 66 | scores, _ = top_scores(sim, 3) 67 | 68 | # Scores and indexes should be sorted ascending 69 | assert np.allclose(scores, np.array([0.0, 0.89442719, 1.0])) 70 | -------------------------------------------------------------------------------- /tests/integration/encoders/test_openai_integration.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from openai import OpenAIError 5 | 6 | from semantic_router.encoders.base import DenseEncoder 7 | from semantic_router.encoders.openai import OpenAIEncoder 8 | 9 | with open("tests/integration/57640.4032.txt", "r") as fp: 10 | long_doc = fp.read() 11 | 12 | 13 | def has_valid_openai_api_key(): 14 | """Check if a valid OpenAI API key is available.""" 15 | api_key = os.environ.get("OPENAI_API_KEY") 16 | return api_key is not None and api_key.strip() != "" 17 | 18 | 19 | @pytest.fixture 20 | def openai_encoder(): 21 | if not has_valid_openai_api_key(): 22 | return DenseEncoder() 23 | else: 24 | return OpenAIEncoder() 25 | 26 | 27 | class TestOpenAIEncoder: 28 | @pytest.mark.skipif( 29 | not has_valid_openai_api_key(), reason="OpenAI API key required" 30 | ) 31 | def test_openai_encoder_init_success(self, openai_encoder): 32 | assert openai_encoder._client is not None 33 | 34 | @pytest.mark.skipif( 35 | not has_valid_openai_api_key(), reason="OpenAI API key required" 36 | ) 37 | def test_openai_encoder_dims(self, openai_encoder): 38 | embeddings = openai_encoder(["test document"]) 39 | assert len(embeddings) == 1 40 | assert len(embeddings[0]) == 1536 41 | 42 | @pytest.mark.skipif( 43 | not has_valid_openai_api_key(), reason="OpenAI API key required" 44 | ) 45 | def test_openai_encoder_call_truncation(self, openai_encoder): 46 | openai_encoder([long_doc]) 47 | 48 | @pytest.mark.skipif( 49 | not has_valid_openai_api_key(), reason="OpenAI API key required" 50 | ) 51 | def test_openai_encoder_call_no_truncation(self, openai_encoder): 52 | with pytest.raises(OpenAIError) as _: 53 | # default truncation is True 54 | openai_encoder([long_doc], truncate=False) 55 | 56 | @pytest.mark.skipif( 57 | not has_valid_openai_api_key(), reason="OpenAI API key required" 58 | ) 59 | def test_openai_encoder_call_uninitialized_client(self, openai_encoder): 60 | # Set the client to None to simulate an uninitialized client 61 | openai_encoder._client = None 62 | with pytest.raises(ValueError) as e: 63 | openai_encoder(["test document"]) 64 | assert "OpenAI client is not initialized." 
in str(e.value) 65 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_base.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from semantic_router.encoders import DenseEncoder 4 | 5 | 6 | class TestDenseEncoder: 7 | @pytest.fixture 8 | def base_encoder(self): 9 | return DenseEncoder(name="TestEncoder", score_threshold=0.5) 10 | 11 | def test_base_encoder_initialization(self, base_encoder): 12 | assert base_encoder.name == "TestEncoder", "Initialization of name failed" 13 | assert base_encoder.score_threshold == 0.5 14 | 15 | def test_base_encoder_call_method_not_implemented(self, base_encoder): 16 | with pytest.raises(NotImplementedError): 17 | base_encoder(["some", "texts"]) 18 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_bm25.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from semantic_router.encoders import BM25Encoder 7 | from semantic_router.route import Route 8 | 9 | UTTERANCES = [ 10 | "Hello we need this text to be a little longer for our sparse encoders", 11 | "In this case they need to learn from recurring tokens, ie words.", 12 | "We give ourselves several examples from our encoders to learn from.", 13 | "But given this is only an example we don't need too many", 14 | "Just enough to test that our sparse encoders work as expected", 15 | ] 16 | 17 | 18 | @pytest.fixture 19 | @pytest.mark.skipif( 20 | os.environ.get("RUN_HF_TESTS") is None, 21 | reason="Set RUN_HF_TESTS=1 to run. This test downloads models from Hugging Face which can time out in CI.", 22 | ) 23 | def bm25_encoder(): 24 | sparse_encoder = BM25Encoder(use_default_params=True) 25 | sparse_encoder.fit( 26 | [ 27 | Route( 28 | name="test_route", 29 | utterances=[ 30 | "The quick brown fox", 31 | "jumps over the lazy dog", 32 | "Hello, world!", 33 | ], 34 | ) 35 | ] 36 | ) 37 | return sparse_encoder 38 | 39 | 40 | @pytest.fixture 41 | def routes(): 42 | return [ 43 | Route(name="Route 1", utterances=[UTTERANCES[0], UTTERANCES[1]]), 44 | Route(name="Route 2", utterances=[UTTERANCES[2], UTTERANCES[3], UTTERANCES[4]]), 45 | ] 46 | 47 | 48 | class TestBM25Encoder: 49 | def test_initialization(self, bm25_encoder): 50 | assert bm25_encoder._tokenizer is not None 51 | 52 | def test_fit(self, bm25_encoder, routes): 53 | bm25_encoder.fit(routes) 54 | assert bm25_encoder._tokenizer is not None 55 | 56 | def test_fit_with_strings(self, bm25_encoder): 57 | route_strings = ["test a", "test b", "test c"] 58 | with pytest.raises(TypeError): 59 | bm25_encoder.fit(route_strings) 60 | 61 | def test_call_method(self, bm25_encoder): 62 | result = bm25_encoder(["test"]) 63 | assert isinstance(result, list), "Result should be a list" 64 | assert all( 65 | isinstance(sparse_emb.embedding, np.ndarray) for sparse_emb in result 66 | ), "Each item in result should be an array" 67 | 68 | def test_call_method_no_docs_bm25_encoder(self, bm25_encoder): 69 | with pytest.raises(ValueError): 70 | bm25_encoder([]) 71 | 72 | def test_call_method_no_word(self, bm25_encoder): 73 | result = bm25_encoder(["doc with fake word gta5jabcxyz"]) 74 | assert isinstance(result, list), "Result should be a list" 75 | assert all( 76 | isinstance(sparse_emb.embedding, np.ndarray) for sparse_emb in result 77 | ), "Each item in result should be an array" 78 | 79 | def 
test_call_method_with_uninitialized_model_or_mapping(self, bm25_encoder): 80 | bm25_encoder._tokenizer = None 81 | with pytest.raises(ValueError): 82 | bm25_encoder(["test"]) 83 | 84 | def test_fit_with_uninitialized_model(self, bm25_encoder, routes): 85 | bm25_encoder._tokenizer = None 86 | with pytest.raises(ValueError): 87 | bm25_encoder.fit(routes) 88 | 89 | def test_encode_queries(self, bm25_encoder): 90 | queries = ["quick brown", "lazy dog", "hello world"] 91 | results = bm25_encoder.encode_queries(queries) 92 | 93 | assert len(results) == len(queries) 94 | assert all([isinstance(result.embedding, np.ndarray) for result in results]) 95 | 96 | def test_encode_queries_empty_list(self, bm25_encoder): 97 | with pytest.raises(ValueError, match="No documents provided for encoding"): 98 | bm25_encoder.encode_queries([]) 99 | 100 | @pytest.mark.skipif( 101 | os.environ.get("RUN_HF_TESTS") is None, 102 | reason="Set RUN_HF_TESTS=1 to run. This test downloads models from Hugging Face which can time out in CI.", 103 | ) 104 | def test_encode_queries_unfitted(self): 105 | encoder = BM25Encoder(use_default_params=True) 106 | with pytest.raises(ValueError, match="Encoder not fitted"): 107 | encoder.encode_queries(["test query"]) 108 | 109 | def test_encode_documents(self, bm25_encoder): 110 | documents = ["quick brown", "lazy dog", "hello world"] 111 | results = bm25_encoder.encode_documents(documents) 112 | 113 | assert len(results) == len(documents) 114 | assert all([isinstance(result.embedding, np.ndarray) for result in results]) 115 | 116 | def test_encode_documents_empty_list(self, bm25_encoder): 117 | with pytest.raises(ValueError, match="No documents provided for encoding"): 118 | bm25_encoder.encode_documents([]) 119 | 120 | @pytest.mark.skipif( 121 | os.environ.get("RUN_HF_TESTS") is None, 122 | reason="Set RUN_HF_TESTS=1 to run. 
This test downloads models from Hugging Face which can time out in CI.", 123 | ) 124 | def test_encode_documents_unfitted(self): 125 | encoder = BM25Encoder(use_default_params=True) 126 | with pytest.raises(ValueError, match="Encoder not fitted"): 127 | encoder.encode_documents(["test document"]) 128 | 129 | def test_encode_documents_batch_size(self, bm25_encoder): 130 | documents = ["quick brown", "lazy dog", "hello world", "test document"] 131 | batch_size = 2 132 | results = bm25_encoder.encode_documents(documents, batch_size=batch_size) 133 | 134 | assert len(results) == len(documents) 135 | assert all(isinstance(result.embedding, np.ndarray) for result in results) 136 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_clip.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | _ = pytest.importorskip("torch") 7 | 8 | from unittest.mock import patch # noqa: E402 9 | 10 | import torch # noqa: E402 11 | from PIL import Image # noqa: E402 12 | 13 | from semantic_router.encoders import CLIPEncoder # noqa: E402 14 | 15 | test_model_name = "aurelio-ai/sr-test-clip" 16 | embed_dim = 64 17 | 18 | if torch.cuda.is_available(): 19 | device = "cuda" 20 | elif torch.backends.mps.is_available(): 21 | device = "mps" 22 | else: 23 | device = "cpu" 24 | 25 | 26 | @pytest.fixture() 27 | def dummy_pil_image(): 28 | return Image.fromarray(np.random.rand(512, 224, 3).astype(np.uint8)) 29 | 30 | 31 | @pytest.fixture() 32 | def dummy_black_and_white_img(): 33 | return Image.fromarray(np.random.rand(224, 224, 2).astype(np.uint8)) 34 | 35 | 36 | @pytest.fixture() 37 | def misshaped_pil_image(): 38 | return Image.fromarray(np.random.rand(64, 64, 3).astype(np.uint8)) 39 | 40 | 41 | class TestClipEncoder: 42 | @pytest.mark.skipif( 43 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 44 | ) 45 | def test_clip_encoder__import_errors_transformers(self): 46 | with patch.dict("sys.modules", {"transformers": None}): 47 | with pytest.raises(ImportError) as error: 48 | CLIPEncoder() 49 | 50 | assert "install transformers" in str(error.value) 51 | 52 | @pytest.mark.skipif( 53 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 54 | ) 55 | def test_clip_encoder__import_errors_torch(self): 56 | with patch.dict("sys.modules", {"torch": None}): 57 | with pytest.raises(ImportError) as error: 58 | CLIPEncoder() 59 | 60 | assert "install Pytorch" in str(error.value) 61 | 62 | @pytest.mark.skipif( 63 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 64 | ) 65 | def test_clip_encoder_initialization(self): 66 | clip_encoder = CLIPEncoder(name=test_model_name) 67 | assert clip_encoder.name == test_model_name 68 | assert clip_encoder.type == "huggingface" 69 | assert clip_encoder.score_threshold == 0.2 70 | assert clip_encoder.device == device 71 | 72 | @pytest.mark.skipif( 73 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 74 | ) 75 | def test_clip_encoder_call_text(self): 76 | clip_encoder = CLIPEncoder(name=test_model_name) 77 | embeddings = clip_encoder(["hello", "world"]) 78 | 79 | assert len(embeddings) == 2 80 | assert len(embeddings[0]) == embed_dim 81 | 82 | @pytest.mark.skipif( 83 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 84 | ) 85 | def test_clip_encoder_call_image(self, dummy_pil_image): 86 | clip_encoder = 
CLIPEncoder(name=test_model_name) 87 | encoded_images = clip_encoder([dummy_pil_image] * 3) 88 | 89 | assert len(encoded_images) == 3 90 | assert set(map(len, encoded_images)) == {embed_dim} 91 | 92 | @pytest.mark.skipif( 93 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 94 | ) 95 | def test_clip_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image): 96 | clip_encoder = CLIPEncoder(name=test_model_name) 97 | encoded_images = clip_encoder([dummy_pil_image, misshaped_pil_image]) 98 | 99 | assert len(encoded_images) == 2 100 | assert set(map(len, encoded_images)) == {embed_dim} 101 | 102 | @pytest.mark.skipif( 103 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 104 | ) 105 | def test_clip_device(self): 106 | clip_encoder = CLIPEncoder(name=test_model_name) 107 | device = clip_encoder._model.device.type 108 | assert device == device 109 | 110 | @pytest.mark.skipif( 111 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 112 | ) 113 | def test_clip_encoder_ensure_rgb(self, dummy_black_and_white_img): 114 | clip_encoder = CLIPEncoder(name=test_model_name) 115 | rgb_image = clip_encoder._ensure_rgb(dummy_black_and_white_img) 116 | 117 | assert rgb_image.mode == "RGB" 118 | assert np.array(rgb_image).shape == (224, 224, 3) 119 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_fastembed.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from semantic_router.encoders import FastEmbedEncoder 4 | 5 | _ = pytest.importorskip("fastembed") 6 | 7 | 8 | class TestFastEmbedEncoder: 9 | def test_fastembed_encoder(self): 10 | encode = FastEmbedEncoder() 11 | test_docs = ["This is a test", "This is another test"] 12 | embeddings = encode(test_docs) 13 | assert isinstance(embeddings, list) 14 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_google.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from google.api_core.exceptions import GoogleAPICallError 3 | from vertexai.language_models import TextEmbedding 4 | from vertexai.language_models._language_models import TextEmbeddingStatistics 5 | 6 | from semantic_router.encoders import GoogleEncoder 7 | 8 | 9 | @pytest.fixture 10 | def google_encoder(mocker): 11 | mocker.patch("google.cloud.aiplatform.init") 12 | mocker.patch("vertexai.language_models.TextEmbeddingModel.from_pretrained") 13 | return GoogleEncoder(project_id="test_project_id") 14 | 15 | 16 | class TestGoogleEncoder: 17 | def test_initialization_with_project_id(self, google_encoder): 18 | assert google_encoder.client is not None, "Client should be initialized" 19 | assert google_encoder.name == "textembedding-gecko@003", ( 20 | "Default name not set correctly" 21 | ) 22 | 23 | def test_initialization_without_project_id(self, mocker, monkeypatch): 24 | monkeypatch.delenv("GOOGLE_PROJECT_ID", raising=False) 25 | mocker.patch("google.cloud.aiplatform.init") 26 | mocker.patch("vertexai.language_models.TextEmbeddingModel.from_pretrained") 27 | with pytest.raises(ValueError): 28 | GoogleEncoder() 29 | 30 | def test_call_method(self, google_encoder, mocker): 31 | mock_embeddings = [ 32 | TextEmbedding( 33 | values=[0.1, 0.2, 0.3], 34 | statistics=TextEmbeddingStatistics(token_count=5, truncated=False), 35 | ) 36 | ] 37 | mocker.patch.object( 38 | google_encoder.client, 
"get_embeddings", return_value=mock_embeddings 39 | ) 40 | 41 | result = google_encoder(["test"]) 42 | assert isinstance(result, list), "Result should be a list" 43 | assert all(isinstance(sublist, list) for sublist in result), ( 44 | "Each item in result should be a list" 45 | ) 46 | google_encoder.client.get_embeddings.assert_called_once() 47 | 48 | def test_returns_list_of_embeddings_for_valid_input(self, google_encoder, mocker): 49 | mock_embeddings = [ 50 | TextEmbedding( 51 | values=[0.1, 0.2, 0.3], 52 | statistics=TextEmbeddingStatistics(token_count=5, truncated=False), 53 | ) 54 | ] 55 | mocker.patch.object( 56 | google_encoder.client, "get_embeddings", return_value=mock_embeddings 57 | ) 58 | 59 | result = google_encoder(["test"]) 60 | assert isinstance(result, list), "Result should be a list" 61 | assert all(isinstance(sublist, list) for sublist in result), ( 62 | "Each item in result should be a list" 63 | ) 64 | google_encoder.client.get_embeddings.assert_called_once() 65 | 66 | def test_handles_multiple_inputs_correctly(self, google_encoder, mocker): 67 | mock_embeddings = [ 68 | TextEmbedding( 69 | values=[0.1, 0.2, 0.3], 70 | statistics=TextEmbeddingStatistics(token_count=5, truncated=False), 71 | ), 72 | TextEmbedding( 73 | values=[0.4, 0.5, 0.6], 74 | statistics=TextEmbeddingStatistics(token_count=6, truncated=False), 75 | ), 76 | ] 77 | mocker.patch.object( 78 | google_encoder.client, "get_embeddings", return_value=mock_embeddings 79 | ) 80 | 81 | result = google_encoder(["test1", "test2"]) 82 | assert isinstance(result, list), "Result should be a list" 83 | assert all(isinstance(sublist, list) for sublist in result), ( 84 | "Each item in result should be a list" 85 | ) 86 | google_encoder.client.get_embeddings.assert_called_once() 87 | 88 | def test_raises_value_error_if_project_id_is_none(self, mocker, monkeypatch): 89 | monkeypatch.delenv("GOOGLE_PROJECT_ID", raising=False) 90 | mocker.patch("google.cloud.aiplatform.init") 91 | mocker.patch("vertexai.language_models.TextEmbeddingModel.from_pretrained") 92 | with pytest.raises(ValueError): 93 | GoogleEncoder() 94 | 95 | def test_raises_value_error_if_google_client_fails_to_initialize(self, mocker): 96 | mocker.patch( 97 | "google.cloud.aiplatform.init", 98 | side_effect=Exception("Failed to initialize client"), 99 | ) 100 | with pytest.raises(ValueError): 101 | GoogleEncoder(project_id="test_project_id") 102 | 103 | def test_raises_value_error_if_google_client_is_not_initialized(self, mocker): 104 | mocker.patch("google.cloud.aiplatform.init") 105 | mocker.patch( 106 | "vertexai.language_models.TextEmbeddingModel.from_pretrained", 107 | return_value=None, 108 | ) 109 | encoder = GoogleEncoder(project_id="test_project_id") 110 | with pytest.raises(ValueError): 111 | encoder(["test"]) 112 | 113 | def test_call_method_raises_error_on_api_failure(self, google_encoder, mocker): 114 | mocker.patch.object( 115 | google_encoder.client, 116 | "get_embeddings", 117 | side_effect=GoogleAPICallError("API call failed"), 118 | ) 119 | with pytest.raises(ValueError): 120 | google_encoder(["test"]) 121 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_hfendpointencoder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from semantic_router.encoders.huggingface import HFEndpointEncoder 4 | 5 | 6 | @pytest.fixture 7 | def encoder(requests_mock): 8 | # Mock the HTTP request made during HFEndpointEncoder initialization 9 | 
requests_mock.post( 10 | "https://api-inference.huggingface.co/models/bert-base-uncased", 11 | json=[0.1, 0.2, 0.3], 12 | status_code=200, 13 | ) 14 | # Now, when HFEndpointEncoder is initialized, it will use the mocked response above 15 | return HFEndpointEncoder( 16 | huggingface_url="https://api-inference.huggingface.co/models/bert-base-uncased", 17 | huggingface_api_key="test-api-key", 18 | score_threshold=0.8, 19 | ) 20 | 21 | 22 | class TestHFEndpointEncoder: 23 | def test_initialization(self, encoder): 24 | assert ( 25 | encoder.huggingface_url 26 | == "https://api-inference.huggingface.co/models/bert-base-uncased" 27 | ) 28 | assert encoder.huggingface_api_key == "test-api-key" 29 | assert encoder.score_threshold == 0.8 30 | 31 | def test_initialization_failure_no_api_key(self): 32 | with pytest.raises(ValueError) as exc_info: 33 | HFEndpointEncoder( 34 | huggingface_url="https://api-inference.huggingface.co/models/bert-base-uncased" 35 | ) 36 | assert "HuggingFace API key cannot be 'None'" in str(exc_info.value) 37 | 38 | def test_initialization_failure_no_url(self): 39 | with pytest.raises(ValueError) as exc_info: 40 | HFEndpointEncoder(huggingface_api_key="test-api-key") 41 | assert "HuggingFace endpoint url cannot be 'None'" in str(exc_info.value) 42 | 43 | def test_query_success(self, encoder, requests_mock): 44 | requests_mock.post( 45 | "https://api-inference.huggingface.co/models/bert-base-uncased", 46 | json=[0.1, 0.2, 0.3], 47 | status_code=200, 48 | ) 49 | response = encoder.query({"inputs": "Hello World!", "parameters": {}}) 50 | assert response == [0.1, 0.2, 0.3] 51 | 52 | def test_query_failure(self, encoder, requests_mock): 53 | requests_mock.post( 54 | "https://api-inference.huggingface.co/models/bert-base-uncased", 55 | text="Error", 56 | status_code=400, 57 | ) 58 | with pytest.raises(ValueError) as exc_info: 59 | encoder.query({"inputs": "Hello World!", "parameters": {}}) 60 | assert "Query failed with status 400: Error" in str(exc_info.value) 61 | 62 | def test_encode_documents_success(self, encoder, requests_mock): 63 | requests_mock.post( 64 | "https://api-inference.huggingface.co/models/bert-base-uncased", 65 | json=[0.1, 0.2, 0.3], 66 | status_code=200, 67 | ) 68 | embeddings = encoder(["Hello World!"]) 69 | assert embeddings == [[0.1, 0.2, 0.3]] 70 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_huggingface.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest.mock import patch 3 | 4 | import numpy as np 5 | import pytest 6 | 7 | _ = pytest.importorskip("transformers") 8 | 9 | from semantic_router.encoders.huggingface import HuggingFaceEncoder # noqa: E402 10 | 11 | test_model_name = "aurelio-ai/sr-test-huggingface" 12 | 13 | 14 | class TestHuggingFaceEncoder: 15 | def test_huggingface_encoder_import_errors_transformers(self): 16 | with patch.dict("sys.modules", {"transformers": None}): 17 | with pytest.raises(ImportError) as error: 18 | HuggingFaceEncoder() 19 | 20 | assert "Please install transformers to use HuggingFaceEncoder" in str( 21 | error.value 22 | ) 23 | 24 | def test_huggingface_encoder_import_errors_torch(self): 25 | with patch.dict("sys.modules", {"torch": None}): 26 | with pytest.raises(ImportError) as error: 27 | HuggingFaceEncoder() 28 | 29 | assert "Please install transformers to use HuggingFaceEncoder" in str( 30 | error.value 31 | ) 32 | 33 | @pytest.mark.skipif( 34 | os.environ.get("RUN_HF_TESTS") is None, 
reason="Set RUN_HF_TESTS=1 to run" 35 | ) 36 | def test_huggingface_encoder_mean_pooling(self): 37 | encoder = HuggingFaceEncoder(name=test_model_name) 38 | test_docs = ["This is a test", "This is another test"] 39 | embeddings = encoder(test_docs, pooling_strategy="mean") 40 | assert isinstance(embeddings, list) 41 | assert len(embeddings) == len(test_docs) 42 | assert all(isinstance(embedding, list) for embedding in embeddings) 43 | assert all(len(embedding) > 0 for embedding in embeddings) 44 | 45 | @pytest.mark.skipif( 46 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 47 | ) 48 | def test_huggingface_encoder_max_pooling(self): 49 | encoder = HuggingFaceEncoder(name=test_model_name) 50 | test_docs = ["This is a test", "This is another test"] 51 | embeddings = encoder(test_docs, pooling_strategy="max") 52 | assert isinstance(embeddings, list) 53 | assert len(embeddings) == len(test_docs) 54 | assert all(isinstance(embedding, list) for embedding in embeddings) 55 | assert all(len(embedding) > 0 for embedding in embeddings) 56 | 57 | @pytest.mark.skipif( 58 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 59 | ) 60 | def test_huggingface_encoder_normalized_embeddings(self): 61 | encoder = HuggingFaceEncoder(name=test_model_name) 62 | docs = ["This is a test document.", "Another test document."] 63 | unnormalized_embeddings = encoder(docs, normalize_embeddings=False) 64 | normalized_embeddings = encoder(docs, normalize_embeddings=True) 65 | assert len(unnormalized_embeddings) == len(normalized_embeddings) 66 | 67 | for unnormalized, normalized in zip( 68 | unnormalized_embeddings, normalized_embeddings 69 | ): 70 | norm_unnormalized = np.linalg.norm(unnormalized, ord=2) 71 | norm_normalized = np.linalg.norm(normalized, ord=2) 72 | # Ensure the norm of the normalized embeddings is approximately 1 73 | assert np.isclose(norm_normalized, 1.0) 74 | # Ensure the normalized embeddings are actually normalized versions of unnormalized embeddings 75 | np.testing.assert_allclose( 76 | normalized, 77 | np.divide(unnormalized, norm_unnormalized), 78 | rtol=1e-5, 79 | atol=1e-5, # Adjust tolerance levels 80 | ) 81 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_lite_encoders.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import litellm 4 | import pytest 5 | from litellm.types.utils import Embedding 6 | 7 | from semantic_router.encoders import ( 8 | CohereEncoder, 9 | JinaEncoder, 10 | LiteLLMEncoder, 11 | MistralEncoder, 12 | NimEncoder, 13 | VoyageEncoder, 14 | ) 15 | 16 | matrix = [ 17 | [ 18 | "openai", 19 | "openai/text-embedding-3-small", 20 | "text-embedding-3-small", 21 | "OPENAI_API_KEY", 22 | LiteLLMEncoder, 23 | ], 24 | [ 25 | "cohere", 26 | "embed-english-v3.0", 27 | "embed-english-v3.0", 28 | "COHERE_API_KEY", 29 | CohereEncoder, 30 | ], 31 | [ 32 | "mistral", 33 | "mistral-embed", 34 | "mistral-embed", 35 | "MISTRAL_API_KEY", 36 | MistralEncoder, 37 | ], 38 | [ 39 | "jina_ai", 40 | "jina-embeddings-v3", 41 | "jina-embeddings-v3", 42 | "JINA_AI_API_KEY", 43 | JinaEncoder, 44 | ], 45 | [ 46 | "voyage", 47 | "voyage-3", 48 | "voyage-3", 49 | "VOYAGE_API_KEY", 50 | VoyageEncoder, 51 | ], 52 | [ 53 | "nvidia_nim", 54 | "nv-embedqa-e5-v5", 55 | "nv-embedqa-e5-v5", 56 | "NVIDIA_NIM_API_KEY", 57 | NimEncoder, 58 | ], 59 | ] 60 | 61 | 62 | @pytest.fixture 63 | def mock_litellm(mocker): 64 | mock_embed = 
litellm.EmbeddingResponse( 65 | data=[ 66 | Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"), 67 | ] 68 | ) 69 | mocker.patch.object(litellm, "embedding", return_value=mock_embed) 70 | return mock_embed 71 | 72 | 73 | @pytest.mark.parametrize( 74 | "provider, model_in, model_name, api_key_env_var, encoder", matrix 75 | ) 76 | class TestEncoders: 77 | def test_initialization_with_api_key( 78 | self, provider, model_in, model_name, api_key_env_var, encoder 79 | ): 80 | os.environ[api_key_env_var] = "test_api_key" 81 | enc = encoder(model_in) 82 | assert enc.name == model_name, "Default name not set correctly" 83 | assert enc.type == provider, "Default type/provider not set correctly" 84 | 85 | def test_initialization_without_api_key( 86 | self, monkeypatch, provider, model_in, model_name, api_key_env_var, encoder 87 | ): 88 | monkeypatch.delenv(api_key_env_var, raising=False) 89 | with pytest.raises(ValueError): 90 | encoder() 91 | 92 | def test_call_method( 93 | self, mock_litellm, provider, model_in, model_name, api_key_env_var, encoder 94 | ): 95 | os.environ[api_key_env_var] = "test_api_key" 96 | result = encoder(model_in)(["test"]) 97 | assert isinstance(result, list), "Result should be a list" 98 | assert all(isinstance(sublist, list) for sublist in result), ( 99 | "Each item in result should be a list" 100 | ) 101 | litellm.embedding.assert_called_once() 102 | 103 | def test_returns_list_of_embeddings_for_valid_input( 104 | self, mock_litellm, provider, model_in, model_name, api_key_env_var, encoder 105 | ): 106 | os.environ[api_key_env_var] = "test_api_key" 107 | result = encoder(model_in)(["test"]) 108 | assert isinstance(result, list), "Result should be a list" 109 | assert all(isinstance(sublist, list) for sublist in result), ( 110 | "Each item in result should be a list" 111 | ) 112 | litellm.embedding.assert_called_once() 113 | 114 | def test_handles_multiple_inputs_correctly( 115 | self, mocker, provider, model_in, model_name, api_key_env_var, encoder 116 | ): 117 | os.environ[api_key_env_var] = "test_api_key" 118 | mock_embed = litellm.EmbeddingResponse( 119 | data=[ 120 | Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"), 121 | Embedding(embedding=[0.4, 0.5, 0.6], index=1, object="embedding"), 122 | ] 123 | ) 124 | mocker.patch.object(litellm, "embedding", return_value=mock_embed) 125 | 126 | result = encoder(model_in)(["test1", "test2"]) 127 | assert isinstance(result, list), "Result should be a list" 128 | assert all(isinstance(sublist, list) for sublist in result), ( 129 | "Each item in result should be a list" 130 | ) 131 | litellm.embedding.assert_called_once() 132 | 133 | def test_call_method_raises_error_on_api_failure( 134 | self, mocker, provider, model_in, model_name, api_key_env_var, encoder 135 | ): 136 | os.environ[api_key_env_var] = "test_api_key" 137 | mocker.patch.object( 138 | litellm, "embedding", side_effect=Exception("API call failed") 139 | ) 140 | with pytest.raises(ValueError): 141 | encoder(model_in)(["test"]) 142 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_local.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from semantic_router.encoders import LocalEncoder 4 | 5 | _ = pytest.importorskip("sentence_transformers") 6 | 7 | 8 | class TestLocalEncoder: 9 | def test_local_encoder(self): 10 | encoder = LocalEncoder() 11 | test_docs = ["This is a test", "This is another test"] 12 | embeddings = 
encoder(test_docs) 13 | assert isinstance(embeddings, list) 14 | assert len(embeddings) == len(test_docs) 15 | assert all(isinstance(embedding, list) for embedding in embeddings) 16 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_ollama.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest.mock import Mock, patch 3 | 4 | import pytest 5 | 6 | pytest.importorskip("ollama") 7 | 8 | from semantic_router.encoders.ollama import OllamaEncoder 9 | 10 | 11 | @pytest.fixture(autouse=True, scope="session") 12 | def set_pinecone_api_key(): 13 | os.environ["PINECONE_API_KEY"] = "test" 14 | 15 | 16 | @pytest.fixture 17 | def mock_ollama_client(): 18 | with patch("ollama.Client") as mock_client: 19 | yield mock_client 20 | 21 | 22 | class TestOllamaEncoder: 23 | def test_ollama_encoder_init_success(self, mocker): 24 | mocker.patch("ollama.Client", return_value=Mock()) 25 | encoder = OllamaEncoder(base_url="http://localhost:11434") 26 | assert encoder.client is not None 27 | assert encoder.type == "ollama" 28 | 29 | def test_ollama_encoder_init_import_error(self, mocker): 30 | mocker.patch.dict("sys.modules", {"ollama": None}) 31 | with patch( 32 | "builtins.__import__", side_effect=ImportError("No module named 'ollama'") 33 | ): 34 | with pytest.raises(ImportError): 35 | OllamaEncoder(base_url="http://localhost:11434") 36 | 37 | def test_ollama_encoder_call_success(self, mocker): 38 | mock_client = Mock() 39 | mock_embed_result = Mock() 40 | mock_embed_result.embeddings = [[0.1, 0.2], [0.3, 0.4]] 41 | mock_client.embed.return_value = mock_embed_result 42 | mocker.patch("ollama.Client", return_value=mock_client) 43 | encoder = OllamaEncoder(base_url="http://localhost:11434") 44 | encoder.client = mock_client 45 | docs = ["doc1", "doc2"] 46 | result = encoder(docs) 47 | assert result == [[0.1, 0.2], [0.3, 0.4]] 48 | mock_client.embed.assert_called_once_with(model=encoder.name, input=docs) 49 | 50 | def test_ollama_encoder_call_client_not_initialized(self, mocker): 51 | encoder = OllamaEncoder(base_url="http://localhost:11434") 52 | encoder.client = None 53 | with pytest.raises(ValueError) as e: 54 | encoder(["doc1"]) 55 | assert "OLLAMA Platform client is not initialized." in str(e.value) 56 | 57 | def test_ollama_encoder_call_api_error(self, mocker): 58 | mock_client = Mock() 59 | mock_client.embed.side_effect = Exception("API error") 60 | mocker.patch("ollama.Client", return_value=mock_client) 61 | encoder = OllamaEncoder(base_url="http://localhost:11434") 62 | encoder.client = mock_client 63 | with pytest.raises(ValueError) as e: 64 | encoder(["doc1"]) 65 | assert "OLLAMA API call failed. 
Error: API error" in str(e.value) 66 | 67 | def test_ollama_encoder_uses_env_base_url(self, mocker): 68 | test_url = "http://env-ollama:1234" 69 | mock_client = Mock() 70 | mock_client.host = test_url # Set the host attribute on the mock 71 | mocker.patch("ollama.Client", return_value=mock_client) 72 | with patch.dict(os.environ, {"OLLAMA_BASE_URL": test_url}): 73 | encoder = OllamaEncoder() 74 | assert encoder.client is not None 75 | assert encoder.client.host == test_url 76 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_sparse_sentence_transformer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from semantic_router.encoders import LocalSparseEncoder 4 | from semantic_router.schema import SparseEmbedding 5 | 6 | _ = pytest.importorskip("sentence_transformers") 7 | 8 | 9 | class TestLocalSparseEncoder: 10 | def test_sparse_local_encoder(self): 11 | # Use a public SPLADE model for testing 12 | encoder = LocalSparseEncoder(name="naver/splade-cocondenser-ensembledistil") 13 | test_docs = ["This is a test", "This is another test"] 14 | embeddings = encoder(test_docs) 15 | assert isinstance(embeddings, list) 16 | assert len(embeddings) == len(test_docs) 17 | assert all(isinstance(embedding, SparseEmbedding) for embedding in embeddings) 18 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_tfidf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from semantic_router.encoders import TfidfEncoder 5 | from semantic_router.route import Route 6 | 7 | 8 | @pytest.fixture 9 | def tfidf_encoder(): 10 | return TfidfEncoder() 11 | 12 | 13 | class TestTfidfEncoder: 14 | def test_initialization(self, tfidf_encoder): 15 | assert tfidf_encoder.word_index == {} 16 | assert (tfidf_encoder.idf == np.array([])).all() 17 | 18 | def test_fit(self, tfidf_encoder): 19 | routes = [ 20 | Route( 21 | name="test_route", 22 | utterances=["some docs", "and more docs", "and even more docs"], 23 | ) 24 | ] 25 | tfidf_encoder.fit(routes) 26 | assert tfidf_encoder.word_index != {} 27 | assert not np.array_equal(tfidf_encoder.idf, np.array([])) 28 | 29 | def test_call_method(self, tfidf_encoder): 30 | routes = [ 31 | Route( 32 | name="test_route", 33 | utterances=["some docs", "and more docs", "and even more docs"], 34 | ) 35 | ] 36 | tfidf_encoder.fit(routes) 37 | result = tfidf_encoder(["test"]) 38 | assert isinstance(result, list), "Result should be a list" 39 | assert all( 40 | isinstance(sparse_emb.embedding, np.ndarray) for sparse_emb in result 41 | ), "Each item in result should be an array" 42 | 43 | def test_call_method_no_docs_tfidf_encoder(self, tfidf_encoder): 44 | with pytest.raises(ValueError): 45 | tfidf_encoder([]) 46 | 47 | def test_call_method_no_word(self, tfidf_encoder): 48 | routes = [ 49 | Route( 50 | name="test_route", 51 | utterances=["some docs", "and more docs", "and even more docs"], 52 | ) 53 | ] 54 | tfidf_encoder.fit(routes) 55 | result = tfidf_encoder(["doc with fake word gta5jabcxyz"]) 56 | assert isinstance(result, list), "Result should be a list" 57 | assert all( 58 | isinstance(sparse_emb.embedding, np.ndarray) for sparse_emb in result 59 | ), "Each item in result should be an array" 60 | 61 | def test_fit_with_strings(self, tfidf_encoder): 62 | routes = ["test a", "test b", "test c"] 63 | with pytest.raises(TypeError): 64 | 
tfidf_encoder.fit(routes) 65 | 66 | def test_call_method_with_uninitialized_model(self, tfidf_encoder): 67 | with pytest.raises(ValueError): 68 | tfidf_encoder(["test"]) 69 | 70 | def test_compute_tf_no_word_index(self, tfidf_encoder): 71 | with pytest.raises(ValueError, match="Word index is not initialized."): 72 | tfidf_encoder._compute_tf(["some docs"]) 73 | 74 | def test_compute_tf_with_word_in_word_index(self, tfidf_encoder): 75 | routes = [ 76 | Route( 77 | name="test_route", 78 | utterances=["some docs", "and more docs", "and even more docs"], 79 | ) 80 | ] 81 | tfidf_encoder.fit(routes) 82 | tf = tfidf_encoder._compute_tf(["some docs"]) 83 | assert tf.shape == (1, len(tfidf_encoder.word_index)) 84 | 85 | def test_compute_idf_no_word_index(self, tfidf_encoder): 86 | with pytest.raises(ValueError, match="Word index is not initialized."): 87 | tfidf_encoder._compute_idf(["some docs"]) 88 | -------------------------------------------------------------------------------- /tests/unit/encoders/test_vit.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest.mock import patch 3 | 4 | import numpy as np 5 | import pytest 6 | 7 | _ = pytest.importorskip("torch") 8 | 9 | import torch # noqa: E402 10 | from PIL import Image # noqa: E402 11 | 12 | from semantic_router.encoders import VitEncoder # noqa: E402 13 | 14 | test_model_name = "aurelio-ai/sr-test-vit" 15 | embed_dim = 32 16 | 17 | if torch.cuda.is_available(): 18 | device = "cuda" 19 | elif torch.backends.mps.is_available(): 20 | device = "mps" 21 | else: 22 | device = "cpu" 23 | 24 | 25 | @pytest.fixture() 26 | def dummy_pil_image(): 27 | return Image.fromarray(np.random.rand(1024, 512, 3).astype(np.uint8)) 28 | 29 | 30 | @pytest.fixture() 31 | def dummy_black_and_white_img(): 32 | return Image.fromarray(np.random.rand(224, 224, 2).astype(np.uint8)) 33 | 34 | 35 | @pytest.fixture() 36 | def misshaped_pil_image(): 37 | return Image.fromarray(np.random.rand(64, 64, 3).astype(np.uint8)) 38 | 39 | 40 | class TestVitEncoder: 41 | def test_vit_encoder__import_errors_transformers(self): 42 | with patch.dict("sys.modules", {"transformers": None}): 43 | with pytest.raises(ImportError) as error: 44 | VitEncoder() 45 | 46 | assert "Please install transformers to use VitEncoder" in str(error.value) 47 | 48 | @pytest.mark.skipif( 49 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 50 | ) 51 | def test_vit_encoder_initialization(self): 52 | vit_encoder = VitEncoder(name=test_model_name) 53 | assert vit_encoder.name == test_model_name 54 | assert vit_encoder.type == "huggingface" 55 | assert vit_encoder.score_threshold == 0.5 56 | assert vit_encoder.device == device 57 | 58 | @pytest.mark.skipif( 59 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 60 | ) 61 | def test_vit_encoder_call(self, dummy_pil_image): 62 | vit_encoder = VitEncoder(name=test_model_name) 63 | encoded_images = vit_encoder([dummy_pil_image] * 3) 64 | 65 | assert len(encoded_images) == 3 66 | assert set(map(len, encoded_images)) == {embed_dim} 67 | 68 | @pytest.mark.skipif( 69 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 70 | ) 71 | def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image): 72 | vit_encoder = VitEncoder(name=test_model_name) 73 | encoded_images = vit_encoder([dummy_pil_image, misshaped_pil_image]) 74 | 75 | assert len(encoded_images) == 2 76 | assert set(map(len, encoded_images)) == {embed_dim} 77 
| 78 | @pytest.mark.skipif( 79 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 80 | ) 81 | def test_vit_encoder_process_images_device(self, dummy_pil_image): 82 | vit_encoder = VitEncoder(name=test_model_name) 83 | imgs = vit_encoder._process_images([dummy_pil_image] * 3)["pixel_values"] 84 | 85 | assert imgs.device.type == device 86 | 87 | @pytest.mark.skipif( 88 | os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" 89 | ) 90 | def test_vit_encoder_ensure_rgb(self, dummy_black_and_white_img): 91 | vit_encoder = VitEncoder(name=test_model_name) 92 | rgb_image = vit_encoder._ensure_rgb(dummy_black_and_white_img) 93 | 94 | assert rgb_image.mode == "RGB" 95 | assert np.array(rgb_image).shape == (224, 224, 3) 96 | -------------------------------------------------------------------------------- /tests/unit/llms/test_llm_azure_openai.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from semantic_router.llms import AzureOpenAILLM 4 | from semantic_router.schema import Message 5 | 6 | 7 | @pytest.fixture 8 | def azure_openai_llm(mocker): 9 | mocker.patch("openai.Client") 10 | return AzureOpenAILLM(openai_api_key="test_api_key", azure_endpoint="test_endpoint") 11 | 12 | 13 | class TestAzureOpenAILLM: 14 | def test_azure_openai_llm_init_with_api_key(self, azure_openai_llm): 15 | assert azure_openai_llm._client is not None, "Client should be initialized" 16 | assert azure_openai_llm.name == "gpt-4o", "Default name not set correctly" 17 | 18 | def test_azure_openai_llm_init_success(self, mocker): 19 | mocker.patch("os.getenv", return_value="fake-api-key") 20 | llm = AzureOpenAILLM() 21 | assert llm._client is not None 22 | 23 | def test_azure_openai_llm_init_without_api_key(self, mocker): 24 | mocker.patch("os.getenv", return_value=None) 25 | with pytest.raises(ValueError) as _: 26 | AzureOpenAILLM() 27 | 28 | # def test_azure_openai_llm_init_without_azure_endpoint(self, mocker): 29 | # mocker.patch("os.getenv", side_effect=[None, "fake-api-key"]) 30 | # with pytest.raises(ValueError) as e: 31 | # AzureOpenAILLM(openai_api_key="test_api_key") 32 | # assert "Azure endpoint API key cannot be 'None'." in str(e.value) 33 | 34 | def test_azure_openai_llm_init_without_azure_endpoint(self, mocker): 35 | mocker.patch( 36 | "os.getenv", 37 | side_effect=lambda key, default=None: { 38 | "OPENAI_CHAT_MODEL_NAME": "test-model-name" 39 | }.get(key, default), 40 | ) 41 | with pytest.raises(ValueError) as e: 42 | AzureOpenAILLM(openai_api_key="test_api_key") 43 | assert "Azure endpoint API key cannot be 'None'" in str(e.value) 44 | 45 | def test_azure_openai_llm_call_uninitialized_client(self, azure_openai_llm): 46 | # Set the client to None to simulate an uninitialized client 47 | azure_openai_llm._client = None 48 | with pytest.raises(ValueError) as e: 49 | llm_input = [Message(role="user", content="test")] 50 | azure_openai_llm(llm_input) 51 | assert "AzureOpenAI client is not initialized." in str(e.value) 52 | 53 | def test_azure_openai_llm_init_exception(self, mocker): 54 | mocker.patch("os.getenv", return_value="fake-api-key") 55 | mocker.patch( 56 | "openai.AzureOpenAI", side_effect=Exception("Initialization error") 57 | ) 58 | with pytest.raises(ValueError) as e: 59 | AzureOpenAILLM() 60 | assert ( 61 | "AzureOpenAI API client failed to initialize.
Error: Initialization error" 62 | in str(e.value) 63 | ) 64 | 65 | def test_azure_openai_llm_temperature_max_tokens_initialization(self): 66 | test_temperature = 0.5 67 | test_max_tokens = 100 68 | azure_llm = AzureOpenAILLM( 69 | openai_api_key="test_api_key", 70 | azure_endpoint="test_endpoint", 71 | temperature=test_temperature, 72 | max_tokens=test_max_tokens, 73 | ) 74 | 75 | assert azure_llm.temperature == test_temperature, ( 76 | "Temperature not set correctly" 77 | ) 78 | assert azure_llm.max_tokens == test_max_tokens, "Max tokens not set correctly" 79 | 80 | def test_azure_openai_llm_call_success(self, azure_openai_llm, mocker): 81 | mock_completion = mocker.MagicMock() 82 | mock_completion.choices[0].message.content = "test" 83 | 84 | mocker.patch("os.getenv", return_value="fake-api-key") 85 | mocker.patch.object( 86 | azure_openai_llm._client.chat.completions, 87 | "create", 88 | return_value=mock_completion, 89 | ) 90 | llm_input = [Message(role="user", content="test")] 91 | output = azure_openai_llm(llm_input) 92 | assert output == "test" 93 | -------------------------------------------------------------------------------- /tests/unit/llms/test_llm_cohere.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from semantic_router.llms import CohereLLM 4 | from semantic_router.schema import Message 5 | 6 | 7 | @pytest.fixture 8 | def cohere_llm(mocker): 9 | mocker.patch("cohere.Client") 10 | return CohereLLM(cohere_api_key="test_api_key") 11 | 12 | 13 | class TestCohereLLM: 14 | def test_initialization_with_api_key(self, cohere_llm): 15 | assert cohere_llm._client is not None, "Client should be initialized" 16 | assert cohere_llm.name == "command", "Default name not set correctly" 17 | 18 | def test_initialization_without_api_key(self, mocker, monkeypatch): 19 | monkeypatch.delenv("COHERE_API_KEY", raising=False) 20 | mocker.patch("cohere.Client") 21 | with pytest.raises(ValueError): 22 | CohereLLM() 23 | 24 | def test_call_method(self, cohere_llm, mocker): 25 | mock_llm = mocker.MagicMock() 26 | mock_llm.text = "test" 27 | cohere_llm._client.chat.return_value = mock_llm 28 | 29 | llm_input = [Message(role="user", content="test")] 30 | result = cohere_llm(llm_input) 31 | assert isinstance(result, str), "Result should be a str" 32 | cohere_llm._client.chat.assert_called_once() 33 | 34 | def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker): 35 | mocker.patch( 36 | "cohere.Client", side_effect=Exception("Failed to initialize client") 37 | ) 38 | with pytest.raises(ValueError): 39 | CohereLLM(cohere_api_key="test_api_key") 40 | 41 | def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker): 42 | mocker.patch("cohere.Client", return_value=None) 43 | llm = CohereLLM(cohere_api_key="test_api_key") 44 | with pytest.raises(ValueError): 45 | llm("test") 46 | 47 | def test_call_method_raises_error_on_api_failure(self, cohere_llm, mocker): 48 | mocker.patch.object( 49 | cohere_llm._client, "__call__", side_effect=Exception("API call failed") 50 | ) 51 | with pytest.raises(ValueError): 52 | cohere_llm("test") 53 | -------------------------------------------------------------------------------- /tests/unit/llms/test_llm_llamacpp.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | import pytest 4 | 5 | _ = pytest.importorskip("llama_cpp") 6 | 7 | from llama_cpp import Llama # noqa: E402 8 | 9 | from 
semantic_router.llms.llamacpp import LlamaCppLLM # noqa: E402 10 | from semantic_router.schema import Message # noqa: E402 11 | 12 | 13 | @pytest.fixture 14 | def llamacpp_llm(mocker): 15 | mock_llama = mocker.patch("llama_cpp.Llama", spec=Llama) 16 | llm = mock_llama.return_value 17 | return LlamaCppLLM(llm=llm) 18 | 19 | 20 | class TestLlamaCppLLM: 21 | def test_llama_cpp_import_errors(self, llamacpp_llm): 22 | with patch.dict("sys.modules", {"llama_cpp": None}): 23 | with pytest.raises(ImportError) as error: 24 | LlamaCppLLM(llamacpp_llm.llm) 25 | 26 | assert ( 27 | "Please install LlamaCPP to use Llama CPP llm. " 28 | "You can install it with: " 29 | "`pip install 'semantic-router[local]'`" in str(error.value) 30 | ) 31 | 32 | def test_llamacpp_llm_init_success(self, llamacpp_llm): 33 | assert llamacpp_llm.name == "llama.cpp" 34 | assert llamacpp_llm.temperature == 0.2 35 | assert llamacpp_llm.max_tokens == 200 36 | assert llamacpp_llm.llm is not None 37 | 38 | def test_llamacpp_llm_call_success(self, llamacpp_llm, mocker): 39 | llamacpp_llm.llm.create_chat_completion = mocker.Mock( 40 | return_value={"choices": [{"message": {"content": "test"}}]} 41 | ) 42 | 43 | llm_input = [Message(role="user", content="test")] 44 | output = llamacpp_llm(llm_input) 45 | assert output == "test" 46 | 47 | def test_llamacpp_llm_grammar(self, llamacpp_llm): 48 | llamacpp_llm._grammar() 49 | 50 | def test_llamacpp_extract_function_inputs(self, llamacpp_llm, mocker): 51 | llamacpp_llm.llm.create_chat_completion = mocker.Mock( 52 | return_value={ 53 | "choices": [ 54 | {"message": {"content": "{'timezone': 'America/New_York'}"}} 55 | ] 56 | } 57 | ) 58 | test_schema = { 59 | "name": "get_time", 60 | "description": 'Finds the current time in a specific timezone.\n\n:param timezone: The timezone to find the current time in, should\n be a valid timezone from the IANA Time Zone Database like\n "America/New_York" or "Europe/London". Do NOT put the place\n name itself like "rome", or "new york", you must provide\n the IANA format.\n:type timezone: str\n:return: The current time in the specified timezone.', 61 | "signature": "(timezone: str) -> str", 62 | "output": "", 63 | } 64 | test_query = "What time is it in America/New_York?" 65 | 66 | llamacpp_llm.extract_function_inputs( 67 | query=test_query, function_schemas=[test_schema] 68 | ) 69 | 70 | def test_llamacpp_extract_function_inputs_invalid(self, llamacpp_llm, mocker): 71 | with pytest.raises(ValueError): 72 | llamacpp_llm.llm.create_chat_completion = mocker.Mock( 73 | return_value={ 74 | "choices": [ 75 | {"message": {"content": "{'time': 'America/New_York'}"}} 76 | ] 77 | } 78 | ) 79 | test_schema = { 80 | "name": "get_time", 81 | "description": 'Finds the current time in a specific timezone.\n\n:param timezone: The timezone to find the current time in, should\n be a valid timezone from the IANA Time Zone Database like\n "America/New_York" or "Europe/London". Do NOT put the place\n name itself like "rome", or "new york", you must provide\n the IANA format.\n:type timezone: str\n:return: The current time in the specified timezone.', 82 | "signature": "(timezone: str) -> str", 83 | "output": "", 84 | } 85 | test_query = "What time is it in America/New_York?" 
86 | 87 | llamacpp_llm.extract_function_inputs( 88 | query=test_query, function_schemas=[test_schema] 89 | ) 90 | -------------------------------------------------------------------------------- /tests/unit/llms/test_llm_mistral.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | import pytest 4 | 5 | from semantic_router.llms import MistralAILLM 6 | from semantic_router.schema import Message 7 | 8 | 9 | @pytest.fixture 10 | def mistralai_llm(mocker): 11 | mocker.patch("mistralai.client.MistralClient") 12 | return MistralAILLM(mistralai_api_key="test_api_key") 13 | 14 | 15 | class TestMistralAILLM: 16 | def test_mistral_llm_import_errors(self): 17 | with patch.dict("sys.modules", {"mistralai": None}): 18 | with pytest.raises(ImportError) as error: 19 | MistralAILLM() 20 | 21 | assert ( 22 | "Please install MistralAI to use MistralAI LLM. " 23 | "You can install it with: " 24 | "`pip install 'semantic-router[mistralai]'`" in str(error.value) 25 | ) 26 | 27 | def test_mistralai_llm_init_with_api_key(self, mistralai_llm): 28 | assert mistralai_llm._client is not None, "Client should be initialized" 29 | assert mistralai_llm.name == "mistral-tiny", "Default name not set correctly" 30 | 31 | def test_mistralai_llm_init_success(self, mocker): 32 | mocker.patch("os.getenv", return_value="fake-api-key") 33 | llm = MistralAILLM() 34 | assert llm._client is not None 35 | 36 | def test_mistralai_llm_init_without_api_key(self, mocker): 37 | mocker.patch("os.getenv", return_value=None) 38 | with pytest.raises(ValueError) as _: 39 | MistralAILLM() 40 | 41 | def test_mistralai_llm_call_uninitialized_client(self, mistralai_llm): 42 | # Set the client to None to simulate an uninitialized client 43 | mistralai_llm._client = None 44 | with pytest.raises(ValueError) as e: 45 | llm_input = [Message(role="user", content="test")] 46 | mistralai_llm(llm_input) 47 | assert "MistralAI client is not initialized." in str(e.value) 48 | 49 | def test_mistralai_llm_init_exception(self, mocker): 50 | mocker.patch( 51 | "mistralai.client.MistralClient", 52 | side_effect=Exception("Initialization error"), 53 | ) 54 | with pytest.raises(ValueError) as e: 55 | MistralAILLM() 56 | assert "MistralAI API key cannot be 'None'." 
in str(e.value) 57 | 58 | def test_mistralai_llm_call_success(self, mistralai_llm, mocker): 59 | mock_completion = mocker.MagicMock() 60 | mock_completion.choices[0].message.content = "test" 61 | 62 | mocker.patch("os.getenv", return_value="fake-api-key") 63 | mocker.patch.object( 64 | mistralai_llm._client, 65 | "chat", 66 | return_value=mock_completion, 67 | ) 68 | llm_input = [Message(role="user", content="test")] 69 | output = mistralai_llm(llm_input) 70 | assert output == "test" 71 | -------------------------------------------------------------------------------- /tests/unit/llms/test_llm_ollama.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from semantic_router.llms.ollama import OllamaLLM 4 | from semantic_router.schema import Message 5 | 6 | 7 | @pytest.fixture 8 | def ollama_llm(): 9 | return OllamaLLM() 10 | 11 | 12 | class TestOllamaLLM: 13 | def test_ollama_llm_init_success(self, ollama_llm): 14 | assert ollama_llm.temperature == 0.2 15 | assert ollama_llm.name == "openhermes" 16 | assert ollama_llm.max_tokens == 200 17 | assert ollama_llm.stream is False 18 | 19 | def test_ollama_llm_call_success(self, ollama_llm, mocker): 20 | mock_response = mocker.MagicMock() 21 | mock_response.json.return_value = {"message": {"content": "test response"}} 22 | mocker.patch("requests.post", return_value=mock_response) 23 | 24 | output = ollama_llm([Message(role="user", content="test")]) 25 | assert output == "test response" 26 | 27 | def test_ollama_llm_error_handling(self, ollama_llm, mocker): 28 | mocker.patch("requests.post", side_effect=Exception("LLM error")) 29 | with pytest.raises(Exception) as exc_info: 30 | ollama_llm([Message(role="user", content="test")]) 31 | assert "LLM error" in str(exc_info.value) 32 | -------------------------------------------------------------------------------- /tests/unit/llms/test_llm_openrouter.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from semantic_router.llms import OpenRouterLLM 4 | from semantic_router.schema import Message 5 | 6 | 7 | @pytest.fixture 8 | def openrouter_llm(mocker): 9 | mocker.patch("openai.Client") 10 | return OpenRouterLLM(openrouter_api_key="test_api_key") 11 | 12 | 13 | class TestOpenRouterLLM: 14 | def test_openrouter_llm_init_with_api_key(self, openrouter_llm): 15 | assert openrouter_llm._client is not None, "Client should be initialized" 16 | assert openrouter_llm.name == "mistralai/mistral-7b-instruct", ( 17 | "Default name not set correctly" 18 | ) 19 | 20 | def test_openrouter_llm_init_success(self, mocker): 21 | mocker.patch("os.getenv", return_value="fake-api-key") 22 | llm = OpenRouterLLM() 23 | assert llm._client is not None 24 | 25 | def test_openrouter_llm_init_without_api_key(self, mocker): 26 | mocker.patch("os.getenv", return_value=None) 27 | with pytest.raises(ValueError) as _: 28 | OpenRouterLLM() 29 | 30 | def test_openrouter_llm_call_uninitialized_client(self, openrouter_llm): 31 | # Set the client to None to simulate an uninitialized client 32 | openrouter_llm._client = None 33 | with pytest.raises(ValueError) as e: 34 | llm_input = [Message(role="user", content="test")] 35 | openrouter_llm(llm_input) 36 | assert "OpenRouter client is not initialized." 
in str(e.value) 37 | 38 | def test_openrouter_llm_init_exception(self, mocker): 39 | mocker.patch("os.getenv", return_value="fake-api-key") 40 | mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error")) 41 | with pytest.raises(ValueError) as e: 42 | OpenRouterLLM() 43 | assert ( 44 | "OpenRouter API client failed to initialize. Error: Initialization error" 45 | in str(e.value) 46 | ) 47 | 48 | def test_openrouter_llm_call_success(self, openrouter_llm, mocker): 49 | mock_completion = mocker.MagicMock() 50 | mock_completion.choices[0].message.content = "test" 51 | 52 | mocker.patch("os.getenv", return_value="fake-api-key") 53 | mocker.patch.object( 54 | openrouter_llm._client.chat.completions, 55 | "create", 56 | return_value=mock_completion, 57 | ) 58 | llm_input = [Message(role="user", content="test")] 59 | output = openrouter_llm(llm_input) 60 | assert output == "test" 61 | -------------------------------------------------------------------------------- /tests/unit/test_function_schema.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | from semantic_router.utils.function_call import FunctionSchema 4 | 5 | 6 | def scrape_webpage(url: str, name: str = "test") -> str: 7 | """Provides access to web scraping. You can use this tool to scrape a webpage. 8 | Many webpages may return no information due to JS or adblock issues, if this 9 | happens, you must use a different URL. 10 | """ 11 | return "hello there" 12 | 13 | 14 | def test_function_schema(): 15 | schema = FunctionSchema(scrape_webpage) 16 | assert schema.name == scrape_webpage.__name__ 17 | assert schema.description == str(inspect.getdoc(scrape_webpage)) 18 | assert schema.signature == str(inspect.signature(scrape_webpage)) 19 | assert schema.output == str(inspect.signature(scrape_webpage).return_annotation) 20 | assert len(schema.parameters) == 2 21 | 22 | 23 | def test_ollama_function_schema(): 24 | schema = FunctionSchema(scrape_webpage) 25 | ollama_schema = schema.to_ollama() 26 | assert ollama_schema["type"] == "function" 27 | assert ollama_schema["function"]["name"] == schema.name 28 | assert ollama_schema["function"]["description"] == schema.description 29 | assert ollama_schema["function"]["parameters"]["type"] == "object" 30 | assert ( 31 | ollama_schema["function"]["parameters"]["properties"]["url"]["type"] == "string" 32 | ) 33 | assert ( 34 | ollama_schema["function"]["parameters"]["properties"]["name"]["type"] 35 | == "string" 36 | ) 37 | assert ( 38 | ollama_schema["function"]["parameters"]["properties"]["url"]["description"] 39 | is None 40 | ) 41 | assert ( 42 | ollama_schema["function"]["parameters"]["properties"]["name"]["description"] 43 | is None 44 | ) 45 | assert ollama_schema["function"]["parameters"]["required"] == ["name"] 46 | -------------------------------------------------------------------------------- /tests/unit/test_schema.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import ValidationError 3 | 4 | from semantic_router.schema import ( 5 | Message, 6 | ) 7 | 8 | 9 | class TestMessageDataclass: 10 | def test_message_creation(self): 11 | message = Message(role="user", content="Hello!") 12 | assert message.role == "user" 13 | assert message.content == "Hello!" 
14 | 15 | with pytest.raises(ValidationError): 16 | Message(user_role="invalid_role", message="Hello!") 17 | 18 | def test_message_to_openai(self): 19 | message = Message(role="user", content="Hello!") 20 | openai_format = message.to_openai() 21 | assert openai_format == {"role": "user", "content": "Hello!"} 22 | 23 | message = Message(role="invalid_role", content="Hello!") 24 | with pytest.raises(ValueError): 25 | message.to_openai() 26 | 27 | def test_message_to_cohere(self): 28 | message = Message(role="user", content="Hello!") 29 | cohere_format = message.to_cohere() 30 | assert cohere_format == {"role": "user", "message": "Hello!"} 31 | -------------------------------------------------------------------------------- /tests/unit/test_tokenizers.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tempfile 3 | 4 | import numpy as np 5 | import pytest 6 | 7 | from semantic_router.tokenizers import ( 8 | BaseTokenizer, 9 | PretrainedTokenizer, 10 | ) 11 | 12 | 13 | class TestBaseTokenizer: 14 | def test_abstract_methods(self): 15 | class ConcreteTokenizer(BaseTokenizer): 16 | pass 17 | 18 | tokenizer = ConcreteTokenizer() 19 | with pytest.raises(NotImplementedError): 20 | _ = tokenizer.vocab_size 21 | with pytest.raises(NotImplementedError): 22 | _ = tokenizer.config 23 | with pytest.raises(NotImplementedError): 24 | tokenizer.tokenize("test") 25 | 26 | def test_save_load(self): 27 | class ConcreteTokenizer(BaseTokenizer): 28 | def __init__(self, test_param) -> None: 29 | self.test_param = test_param 30 | super().__init__() 31 | 32 | @property 33 | def vocab_size(self): 34 | return 100 35 | 36 | @property 37 | def config(self): 38 | return {"test_param": self.test_param} 39 | 40 | def tokenize(self, texts, pad=True): 41 | pass 42 | 43 | with tempfile.NamedTemporaryFile(suffix=".json") as tmp: 44 | tokenizer = ConcreteTokenizer(test_param="value") 45 | tokenizer.save(tmp.name) 46 | 47 | loaded = ConcreteTokenizer.load(tmp.name) 48 | assert isinstance(loaded, ConcreteTokenizer) 49 | with open(tmp.name) as f: 50 | saved_config = json.load(f) 51 | assert saved_config == {"test_param": "value"} 52 | 53 | 54 | class TestPretrainedTokenizer: 55 | @pytest.fixture 56 | def tokenizer(self): 57 | return PretrainedTokenizer("google-bert/bert-base-uncased") 58 | 59 | def test_initialization(self, tokenizer): 60 | assert tokenizer.model_ident == "google-bert/bert-base-uncased" 61 | assert tokenizer.add_special_tokens is False 62 | assert tokenizer.pad is True 63 | 64 | def test_vocab_size(self, tokenizer): 65 | assert isinstance(tokenizer.vocab_size, int) 66 | assert tokenizer.vocab_size > 0 67 | 68 | def test_config(self, tokenizer): 69 | config = tokenizer.config 70 | assert isinstance(config, dict) 71 | assert "model_ident" in config 72 | assert "add_special_tokens" in config 73 | assert "pad" in config 74 | 75 | def test_tokenize_single_text(self, tokenizer): 76 | text = "Hello world" 77 | tokens = tokenizer.tokenize(text) 78 | assert isinstance(tokens, np.ndarray) 79 | assert tokens.ndim == 2 80 | assert tokens.shape[0] == 1 # One sequence 81 | assert tokens.shape[1] > 0 # At least one token 82 | 83 | def test_tokenize_multiple_texts(self, tokenizer): 84 | texts = ["Hello world", "Testing tokenization"] 85 | tokens = tokenizer.tokenize(texts) 86 | assert isinstance(tokens, np.ndarray) 87 | assert tokens.ndim == 2 88 | assert tokens.shape[0] == 2 # Two sequences 89 | 90 | def test_save_load_cycle(self, tokenizer): 91 | with 
tempfile.NamedTemporaryFile(suffix=".json") as tmp: 92 | tokenizer.save(tmp.name) 93 | loaded = PretrainedTokenizer.load(tmp.name) 94 | 95 | assert isinstance(loaded, PretrainedTokenizer) 96 | assert loaded.model_ident == tokenizer.model_ident 97 | assert loaded.add_special_tokens == tokenizer.add_special_tokens 98 | assert loaded.pad == tokenizer.pad 99 | --------------------------------------------------------------------------------